diff --git a/pytential/qbx/geometry.py b/pytential/qbx/geometry.py index 71160948543fcdc225be382d2ed2ffcf5ef7aec3..e9e7fd928ec702f5ad9167590e82322c4f55996d 100644 --- a/pytential/qbx/geometry.py +++ b/pytential/qbx/geometry.py @@ -781,8 +781,9 @@ class QBXFMMGeometryData(object): nqbx_centers = self.ncenters flags[:nqbx_centers] = 0 - from boxtree.tree import filter_target_lists_in_tree_order - result = filter_target_lists_in_tree_order(queue, self.tree(), flags) + tree = self.tree() + plfilt = self.code_getter.particle_list_filter() + result = plfilt.filter_target_lists_in_tree_order(queue, tree, flags) logger.info("find non-qbx box target lists: done") diff --git a/pytential/qbx/utils.py b/pytential/qbx/utils.py index d03e8365fa188d55a04a78c41b56279af883a318..ecb939de841d7b567b735d52ff26b327a72d50b1 100644 --- a/pytential/qbx/utils.py +++ b/pytential/qbx/utils.py @@ -154,6 +154,11 @@ class TreeCodeContainer(object): from boxtree.area_query import PeerListFinder return PeerListFinder(self.cl_context) + @memoize_method + def particle_list_filter(self): + from boxtree.tree import ParticleListFilter + return ParticleListFilter(self.cl_context) + # }}} @@ -170,6 +175,9 @@ class TreeCodeContainerMixin(object): def peer_list_finder(self): return self.tree_code_container.peer_list_finder() + def particle_list_filter(self): + return self.tree_code_container.particle_list_filter() + # }}} @@ -180,9 +188,10 @@ class TreeWranglerBase(object): def build_tree(self, lpot_source, targets_list=(), use_stage2_discr=False): tb = self.code_container.build_tree() + plfilt = self.code_container.particle_list_filter() from pytential.qbx.utils import build_tree_with_qbx_metadata return build_tree_with_qbx_metadata( - self.queue, tb, lpot_source, targets_list=targets_list, + self.queue, tb, plfilt, lpot_source, targets_list=targets_list, use_stage2_discr=use_stage2_discr) def find_peer_lists(self, tree): @@ -448,7 +457,7 @@ MAX_REFINE_WEIGHT = 64 def build_tree_with_qbx_metadata( - queue, tree_builder, lpot_source, targets_list=(), + queue, tree_builder, particle_list_filter, lpot_source, targets_list=(), use_stage2_discr=False): """Return a :class:`TreeWithQBXMetadata` built from the given layer potential source. This contains particles of four different types: @@ -542,9 +551,9 @@ def build_tree_with_qbx_metadata( flags[particle_slice].fill(1) flags.finish() - from boxtree.tree import filter_target_lists_in_user_order box_to_class = ( - filter_target_lists_in_user_order(queue, tree, flags) + particle_list_filter + .filter_target_lists_in_user_order(queue, tree, flags) .with_queue(queue)) if fixup: