diff --git a/pytential/qbx/fmmlib.py b/pytential/qbx/fmmlib.py index f0021cce4e332a0c7239d4eacaa61f50858f6251..66618e4f01bac69952011df1934228c370a68417 100644 --- a/pytential/qbx/fmmlib.py +++ b/pytential/qbx/fmmlib.py @@ -292,133 +292,49 @@ class QBXFMMLibExpansionWrangler(FMMLibExpansionWrangler): # {{{ p2qbxl - @memoize_method - def _info_for_form_global_qbx_locals(self): - logger.info("preparing interaction list for p2qbxl: start") - - geo_data = self.geo_data - traversal = geo_data.traversal() - - starts = traversal.neighbor_source_boxes_starts - lists = traversal.neighbor_source_boxes_lists - - qbx_center_to_target_box = geo_data.qbx_center_to_target_box() - qbx_centers = geo_data.centers() - - center_source_counts = [0] - for itgt_center, tgt_icenter in enumerate(geo_data.global_qbx_centers()): - itgt_box = qbx_center_to_target_box[tgt_icenter] - - isrc_box_start = starts[itgt_box] - isrc_box_stop = starts[itgt_box+1] - - source_count = sum( - self.tree.box_source_counts_nonchild[lists[isrc_box]] - for isrc_box in range(isrc_box_start, isrc_box_stop)) - - center_source_counts.append(source_count) - - center_source_counts = np.array(center_source_counts) - center_source_starts = np.cumsum(center_source_counts) - nsources_total = center_source_starts[-1] - center_source_offsets = np.empty(nsources_total, np.int32) - - isource = 0 - for itgt_center, tgt_icenter in enumerate(geo_data.global_qbx_centers()): - assert isource == center_source_starts[itgt_center] - itgt_box = qbx_center_to_target_box[tgt_icenter] - - isrc_box_start = starts[itgt_box] - isrc_box_stop = starts[itgt_box+1] - - for isrc_box in range(isrc_box_start, isrc_box_stop): - src_ibox = lists[isrc_box] - - src_pslice = self._get_source_slice(src_ibox) - ns = self.tree.box_source_counts_nonchild[src_ibox] - center_source_offsets[isource:isource+ns] = np.arange( - src_pslice.start, src_pslice.stop) - - isource += ns - - centers = qbx_centers[:, geo_data.global_qbx_centers()] - rscale_vec = geo_data.expansion_radii()[geo_data.global_qbx_centers()] - - nsources_vec = np.ones(self.tree.nsources, np.int32) - - logger.info("preparing interaction list for p2qbxl: done") - - return P2QBXLInfo( - centers=centers, - center_source_starts=center_source_starts, - center_source_offsets=center_source_offsets, - nsources_vec=nsources_vec, - rscale_vec=rscale_vec, - ngqbx_centers=centers.shape[1], - ) - def form_global_qbx_locals(self, src_weights): geo_data = self.geo_data - - qbx_exps = self.qbx_local_expansion_zeros() + trav = geo_data.traversal() if len(geo_data.global_qbx_centers()) == 0: - return qbx_exps + return self.qbx_local_expansion_zeros() - formta_imany = self.get_routine("%ddformta" + self.dp_suffix, - suffix="_imany") - info = self._info_for_form_global_qbx_locals() + formta_qbx = self.get_routine("%ddformta" + self.dp_suffix, + suffix="_qbx") kwargs = {} kwargs.update(self.kernel_kwargs) if self.dipole_vec is None: kwargs["charge"] = src_weights - kwargs["charge_offsets"] = info.center_source_offsets - kwargs["charge_starts"] = info.center_source_starts else: - kwargs["dipstr_offsets"] = info.center_source_offsets - kwargs["dipstr_starts"] = info.center_source_starts - if self.dim == 2 and self.eqn_letter == "l": kwargs["dipstr"] = -src_weights * ( self.dipole_vec[0] + 1j*self.dipole_vec[1]) else: kwargs["dipstr"] = src_weights - kwargs["dipvec"] = self.dipole_vec - kwargs["dipvec_offsets"] = info.center_source_offsets - kwargs["dipvec_starts"] = info.center_source_starts - - # These get max'd/added onto: pass initialized versions. - ier = np.zeros(info.ngqbx_centers, dtype=np.int32) - expn = np.zeros( - (info.ngqbx_centers,) + self.expansion_shape(self.qbx_order), - dtype=self.dtype) - - ier, expn = formta_imany( - rscale=info.rscale_vec, + ier, qbx_exps = formta_qbx( sources=self._get_single_sources_array(), - sources_offsets=info.center_source_offsets, - sources_starts=info.center_source_starts, - - nsources=info.nsources_vec, - nsources_offsets=info.center_source_offsets, - nsources_starts=info.center_source_starts, - - center=info.centers, - - ier=ier, - expn=expn.T, - + qbx_centers=geo_data.centers().T, + global_qbx_centers=geo_data.global_qbx_centers(), + qbx_expansion_radii=geo_data.expansion_radii(), + qbx_center_to_target_box=geo_data.qbx_center_to_target_box(), + nterms=self.qbx_order, + source_box_starts=trav.neighbor_source_boxes_starts, + source_box_lists=trav.neighbor_source_boxes_lists, + box_source_starts=self.tree.box_source_starts, + box_source_counts_nonchild=self.tree.box_source_counts_nonchild, **kwargs) + qbx_exps = qbx_exps.T if np.any(ier != 0): - raise RuntimeError("formta returned an error") + raise RuntimeError("formta for p2qbxl returned an error (ier=%d)" % ier) - qbx_exps[geo_data.global_qbx_centers()] = expn.T + qbx_exps_2 = self.qbx_local_expansion_zeros() + assert qbx_exps.shape == qbx_exps_2.shape return qbx_exps