diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index 3568601ff3e02c81cb78839663cd386657f54e81..adb858a119921812cf0c9b7084a2cedfb0397856 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -575,9 +575,6 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): # }}} - if len(geo_data.global_qbx_centers()) != geo_data.ncenters: - raise NotImplementedError("geometry has centers requiring local QBX") - from pytential.qbx.geometry import target_state if (geo_data.user_target_to_center().with_queue(queue) == target_state.FAILED).any().get(): diff --git a/pytential/qbx/fmmlib.py b/pytential/qbx/fmmlib.py index 9ec0194cbd1c3d8494d484f5c14063ca46d32d56..bbb37ab0703d99ed352acda54bb52737d184ebdb 100644 --- a/pytential/qbx/fmmlib.py +++ b/pytential/qbx/fmmlib.py @@ -419,6 +419,10 @@ class QBXFMMLibExpansionWrangler(FMMLibExpansionWrangler): qbx_center_to_target_box = geo_data.qbx_center_to_target_box() qbx_centers = geo_data.centers() centers = self.tree.box_centers + ngqbx_centers = len(geo_data.global_qbx_centers()) + + if ngqbx_centers == 0: + return local_exps mploc = self.get_translation_routine("%ddmploc", vec_suffix="_imany") @@ -429,7 +433,6 @@ class QBXFMMLibExpansionWrangler(FMMLibExpansionWrangler): print("par data prep lev %d" % isrc_level) - ngqbx_centers = len(geo_data.global_qbx_centers()) tgt_icenter_vec = geo_data.global_qbx_centers() icontaining_tgt_box_vec = qbx_center_to_target_box[tgt_icenter_vec] diff --git a/pytential/qbx/geometry.py b/pytential/qbx/geometry.py index af8da4ca928524fb0085c7744331981118806568..96bacdf18deeadcdd0db9e2131883a52ccff69a4 100644 --- a/pytential/qbx/geometry.py +++ b/pytential/qbx/geometry.py @@ -222,6 +222,28 @@ class QBXFMMGeometryCodeGetter(object): if (i+1 == N) *count = item; """) + @property + @memoize_method + def pick_used_centers(self): + knl = lp.make_kernel( + """{[i]: 0<=itarget_has_center = (target_to_center[i] >= 0) + center_is_used[target_to_center[i]] = 1 \ + {id=center_is_used_write,if=target_has_center} + """, + [ + lp.GlobalArg("target_to_center", shape="ntargets", offset=lp.auto), + lp.GlobalArg("center_is_used", shape="ncenters"), + lp.ValueArg("ncenters", np.int32), + lp.ValueArg("ntargets", np.int32), + ], + name="pick_used_centers", + silenced_warnings="write_race(center_is_used_write)") + + knl = lp.split_iname(knl, "i", 128, inner_tag="l.0", outer_tag="g.0") + return knl + # }}} @@ -603,25 +625,47 @@ class QBXFMMGeometryData(object): """Build a list of indices of QBX centers that use global QBX. This indexes into the global list of targets, (see :meth:`target_info`) of which the QBX centers occupy the first *ncenters*. + + Centers without any associated targets are excluded. """ tree = self.tree() + user_target_to_center = self.user_target_to_center() with cl.CommandQueue(self.cl_context) as queue: + logger.info("find global qbx centers: start") + + tgt_assoc_result = ( + user_target_to_center.with_queue(queue)[self.ncenters:]) + + center_is_used = cl.array.zeros(queue, self.ncenters, np.int8) + + self.code_getter.pick_used_centers( + queue, + center_is_used=center_is_used, + target_to_center=tgt_assoc_result, + ncenters=self.ncenters, + ntargets=len(tgt_assoc_result)) + from pyopencl.algorithm import copy_if - logger.info("find global qbx centers: start") result, count, _ = copy_if( cl.array.arange(queue, self.ncenters, tree.particle_id_dtype), - "global_qbx_flags[i] != 0", + "global_qbx_flags[i] != 0 && center_is_used[i] != 0", extra_args=[ - ("global_qbx_flags", self.global_qbx_flags()) + ("global_qbx_flags", self.global_qbx_flags()), + ("center_is_used", center_is_used) ], queue=queue) logger.info("find global qbx centers: done") + if self.debug: + logger.debug( + "find global qbx centers: using %d/%d centers" + % (int(count.get()), self.ncenters)) + return result[:count.get()].with_queue(None) @memoize_method