diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index a5a5d11504ed80f0973490e2ad5201d187108383..4a12d1c16d228ed715d45bc89d1b8a2057fba58b 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -317,17 +317,24 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): return conn + @property + @memoize_method + def tree_code_container(self): + from pytential.qbx.utils import TreeCodeContainer + return TreeCodeContainer(self.cl_context) + @property @memoize_method def refiner_code_container(self): from pytential.qbx.refinement import RefinerCodeContainer - return RefinerCodeContainer(self.cl_context) + return RefinerCodeContainer(self.cl_context, self.tree_code_container) @property @memoize_method def target_association_code_container(self): from pytential.qbx.target_assoc import TargetAssociationCodeContainer - return TargetAssociationCodeContainer(self.cl_context) + return TargetAssociationCodeContainer( + self.cl_context, self.tree_code_container) @memoize_method def with_refinement(self, target_order=None, kernel_length_scale=None, @@ -481,7 +488,8 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): def qbx_fmm_code_getter(self): from pytential.qbx.geometry import QBXFMMGeometryCodeGetter return QBXFMMGeometryCodeGetter(self.cl_context, self.ambient_dim, - debug=self.debug, _well_sep_is_n_away=self._well_sep_is_n_away) + self.tree_code_container, debug=self.debug, + _well_sep_is_n_away=self._well_sep_is_n_away) # {{{ fmm-based execution diff --git a/pytential/qbx/geometry.py b/pytential/qbx/geometry.py index 374d37d26e0cedbc36a0ceb2e114797a34ee4acc..55b40e2ed2c57dfc556fa038ba1244c1ad856f49 100644 --- a/pytential/qbx/geometry.py +++ b/pytential/qbx/geometry.py @@ -33,6 +33,9 @@ import loopy as lp from cgen import Enum +from pytential.qbx.utils import TreeCodeContainerMixin + + import logging logger = logging.getLogger(__name__) @@ -103,10 +106,12 @@ class target_state(Enum): # noqa FAILED = -2 -class QBXFMMGeometryCodeGetter(object): - def __init__(self, cl_context, ambient_dim, debug, _well_sep_is_n_away): +class QBXFMMGeometryCodeGetter(TreeCodeContainerMixin): + def __init__(self, cl_context, ambient_dim, tree_code_container, debug, + _well_sep_is_n_away): self.cl_context = cl_context self.ambient_dim = ambient_dim + self.tree_code_container = tree_code_container self.debug = debug self._well_sep_is_n_away = _well_sep_is_n_away @@ -129,12 +134,6 @@ class QBXFMMGeometryCodeGetter(object): knl = lp.tag_array_axes(knl, "targets", "stride:auto, stride:1") return lp.tag_inames(knl, dict(dim="ilp")) - @property - @memoize_method - def build_tree(self): - from boxtree import TreeBuilder - return TreeBuilder(self.cl_context) - @property @memoize_method def build_traversal(self): @@ -500,7 +499,7 @@ class QBXFMMGeometryData(object): refine_weights.finish() - tree, _ = code_getter.build_tree(queue, + tree, _ = code_getter.build_tree()(queue, particles=lpot_src.quad_stage2_density_discr.nodes(), targets=target_info.targets, target_radii=target_radii, diff --git a/pytential/qbx/refinement.py b/pytential/qbx/refinement.py index ce183e582cbb45e018e3e0d7ef90c6313b2dd000..c911fe983a19933ab2098df84e6cc4c585a675d8 100644 --- a/pytential/qbx/refinement.py +++ b/pytential/qbx/refinement.py @@ -35,7 +35,8 @@ from pytools import memoize_method from boxtree.area_query import AreaQueryElementwiseTemplate from boxtree.tools import InlineBinarySearch from pytential.qbx.utils import ( - QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS, TreeWranglerBase) + QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS, TreeWranglerBase, + TreeCodeContainerMixin) import logging logger = logging.getLogger(__name__) @@ -213,10 +214,11 @@ SUFFICIENT_SOURCE_QUADRATURE_RESOLUTION_CHECKER = AreaQueryElementwiseTemplate( # {{{ code container -class RefinerCodeContainer(object): +class RefinerCodeContainer(TreeCodeContainerMixin): - def __init__(self, cl_context): + def __init__(self, cl_context, tree_code_container): self.cl_context = cl_context + self.tree_code_container = tree_code_container @memoize_method def expansion_disk_undisturbed_by_sources_checker( @@ -257,16 +259,6 @@ class RefinerCodeContainer(object): knl = lp.split_iname(knl, "panel", 128, inner_tag="l.0", outer_tag="g.0") return knl - @memoize_method - def tree_builder(self): - from boxtree.tree_build import TreeBuilder - return TreeBuilder(self.cl_context) - - @memoize_method - def peer_list_finder(self): - from boxtree.area_query import PeerListFinder - return PeerListFinder(self.cl_context) - def get_wrangler(self, queue): """ :arg queue: diff --git a/pytential/qbx/target_assoc.py b/pytential/qbx/target_assoc.py index 9870195a9bfde66c0a148971d1ff9427e273301e..7b9736ce4b6d34f70dcb411bfcba387ea2ec7889 100644 --- a/pytential/qbx/target_assoc.py +++ b/pytential/qbx/target_assoc.py @@ -37,7 +37,8 @@ from boxtree.area_query import AreaQueryElementwiseTemplate from boxtree.tools import InlineBinarySearch from cgen import Enum from pytential.qbx.utils import ( - QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS, TreeWranglerBase) + QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS, TreeWranglerBase, + TreeCodeContainerMixin) unwrap_args = AreaQueryElementwiseTemplate.unwrap_args @@ -380,10 +381,11 @@ class QBXTargetAssociation(DeviceDataRecord): pass -class TargetAssociationCodeContainer(object): +class TargetAssociationCodeContainer(TreeCodeContainerMixin): - def __init__(self, cl_context): + def __init__(self, cl_context, tree_code_container): self.cl_context = cl_context + self.tree_code_container = tree_code_container @memoize_method def target_marker(self, dimensions, coord_dtype, box_id_dtype, @@ -421,21 +423,11 @@ class TargetAssociationCodeContainer(object): max_levels, extra_type_aliases=(("particle_id_t", particle_id_dtype),)) - @memoize_method - def peer_list_finder(self): - from boxtree.area_query import PeerListFinder - return PeerListFinder(self.cl_context) - @memoize_method def space_invader_query(self): from boxtree.area_query import SpaceInvaderQueryBuilder return SpaceInvaderQueryBuilder(self.cl_context) - @memoize_method - def tree_builder(self): - from boxtree.tree_build import TreeBuilder - return TreeBuilder(self.cl_context) - def get_wrangler(self, queue): return TargetAssociationWrangler(self, queue) diff --git a/pytential/qbx/utils.py b/pytential/qbx/utils.py index c27b9a870e2d83e1fee45282bb5b7fd4d8a2bf47..d03e8365fa188d55a04a78c41b56279af883a318 100644 --- a/pytential/qbx/utils.py +++ b/pytential/qbx/utils.py @@ -31,7 +31,7 @@ import numpy as np from boxtree.tree import Tree import pyopencl as cl import pyopencl.array # noqa -from pytools import memoize +from pytools import memoize, memoize_method import logging logger = logging.getLogger(__name__) @@ -137,13 +137,49 @@ def get_interleaved_radii(queue, lpot_source): # }}} -# {{{ peer list wrangler mixin +# {{{ tree code container + +class TreeCodeContainer(object): + + def __init__(self, cl_context): + self.cl_context = cl_context + + @memoize_method + def build_tree(self): + from boxtree.tree_build import TreeBuilder + return TreeBuilder(self.cl_context) + + @memoize_method + def peer_list_finder(self): + from boxtree.area_query import PeerListFinder + return PeerListFinder(self.cl_context) + +# }}} + + +# {{{ tree code container mixin + +class TreeCodeContainerMixin(object): + """Forwards requests for tree-related code to an inner code container named + self.tree_code_container. + """ + + def build_tree(self): + return self.tree_code_container.build_tree() + + def peer_list_finder(self): + return self.tree_code_container.peer_list_finder() + +# }}} + + +# {{{ tree wrangler base class class TreeWranglerBase(object): def build_tree(self, lpot_source, targets_list=(), use_stage2_discr=False): - tb = self.code_container.tree_builder() + tb = self.code_container.build_tree() from pytential.qbx.utils import build_tree_with_qbx_metadata return build_tree_with_qbx_metadata( self.queue, tb, lpot_source, targets_list=targets_list, diff --git a/test/test_global_qbx.py b/test/test_global_qbx.py index fe45746e79e6fe1e03738f278601946cf4260dca..cc0f26f0ebbce0cee1ab5c423fb92f6ed94b9c66 100644 --- a/test/test_global_qbx.py +++ b/test/test_global_qbx.py @@ -94,6 +94,8 @@ def run_source_refinement_test(ctx_getter, mesh, order, helmholtz_k=None): from pytential.qbx.refinement import ( RefinerCodeContainer, refine_for_global_qbx) + from pytential.qbx.utils import TreeCodeContainer + lpot_source = QBXLayerPotentialSource(discr, order) del discr @@ -105,7 +107,9 @@ def run_source_refinement_test(ctx_getter, mesh, order, helmholtz_k=None): refiner_extra_kwargs["kernel_length_scale"] = 5/helmholtz_k lpot_source, conn = refine_for_global_qbx( - lpot_source, RefinerCodeContainer(cl_ctx).get_wrangler(queue), + lpot_source, + RefinerCodeContainer( + cl_ctx, TreeCodeContainer(cl_ctx)).get_wrangler(queue), factory, **refiner_extra_kwargs) from pytential.qbx.utils import get_centers_on_side @@ -294,7 +298,10 @@ def test_target_association(ctx_getter, curve_name, curve_f, nelements): from pytential.qbx.target_assoc import ( TargetAssociationCodeContainer, associate_targets_to_qbx_centers) - code_container = TargetAssociationCodeContainer(cl_ctx) + from pytential.qbx.utils import TreeCodeContainer + + code_container = TargetAssociationCodeContainer( + cl_ctx, TreeCodeContainer(cl_ctx)) target_assoc = (associate_targets_to_qbx_centers( lpot_source, @@ -401,7 +408,10 @@ def test_target_association_failure(ctx_getter): TargetAssociationCodeContainer, associate_targets_to_qbx_centers, QBXTargetAssociationFailedException) - code_container = TargetAssociationCodeContainer(cl_ctx) + from pytential.qbx.utils import TreeCodeContainer + + code_container = TargetAssociationCodeContainer( + cl_ctx, TreeCodeContainer(cl_ctx)) with pytest.raises(QBXTargetAssociationFailedException): associate_targets_to_qbx_centers(