diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index 2ab1488a02198022c39bec95566cd9d77f6b90f2..0b94f90b8ad45f736582f3e0b684226d11ebcbac 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -279,6 +279,12 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): from pytential.qbx.refinement import RefinerCodeContainer return RefinerCodeContainer(self.cl_context) + @property + @memoize_method + def target_association_code_container(self): + from pytential.qbx.target_assoc import TargetAssociationCodeContainer + return TargetAssociationCodeContainer(self.cl_context) + @memoize_method def with_refinement(self, target_order=None, kernel_length_scale=None, maxiter=10): @@ -296,12 +302,13 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): if target_order is None: target_order = self.density_discr.groups[0].order - lpot, connection = refine_for_global_qbx( - self, - self.refiner_code_container, - InterpolatoryQuadratureSimplexGroupFactory(target_order), - kernel_length_scale=kernel_length_scale, - maxiter=maxiter) + with cl.CommandQueue(self.cl_context) as queue: + lpot, connection = refine_for_global_qbx( + self, + self.refiner_code_container.get_wrangler(queue), + InterpolatoryQuadratureSimplexGroupFactory(target_order), + kernel_length_scale=kernel_length_scale, + maxiter=maxiter) return lpot, connection @@ -716,7 +723,8 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): # First ncenters targets are the centers tgt_to_qbx_center = ( geo_data.user_target_to_center()[geo_data.ncenters:] - .copy(queue=queue)) + .copy(queue=queue) + .with_queue(queue)) qbx_tgt_numberer = self.get_qbx_target_numberer( tgt_to_qbx_center.dtype) diff --git a/pytential/qbx/geometry.py b/pytential/qbx/geometry.py index 53f4de33eeab5c8ad618651bb6bd066b4f676ea3..b08370dfe697b2acc482256a2770d42c5d46d2a6 100644 --- a/pytential/qbx/geometry.py +++ b/pytential/qbx/geometry.py @@ -665,11 +665,7 @@ class QBXFMMGeometryData(object): Shape: ``[ntargets]`` of :attr:`boxtree.Tree.particle_id_dtype`, with extra values from :class:`target_state` allowed. Targets occur in user order. """ - from pytential.qbx.target_assoc import QBXTargetAssociator - - # FIXME: kernel ownership... - tgt_assoc = QBXTargetAssociator(self.cl_context) - + from pytential.qbx.target_assoc import associate_targets_to_qbx_centers tgt_info = self.target_info() from pytential.target import PointsTarget @@ -678,24 +674,28 @@ class QBXFMMGeometryData(object): target_side_prefs = (self .target_side_preferences()[self.ncenters:].get(queue=queue)) - target_discrs_and_qbx_sides = [( - PointsTarget(tgt_info.targets[:, self.ncenters:]), - target_side_prefs.astype(np.int32))] + target_discrs_and_qbx_sides = [( + PointsTarget(tgt_info.targets[:, self.ncenters:]), + target_side_prefs.astype(np.int32))] - # FIXME: try block... - tgt_assoc_result = tgt_assoc(self.lpot_source, - target_discrs_and_qbx_sides, - target_association_tolerance=( - self.target_association_tolerance)) + target_association_wrangler = ( + self.lpot_source.target_association_code_container + .get_wrangler(queue)) - tree = self.tree() + tgt_assoc_result = associate_targets_to_qbx_centers( + self.lpot_source, + target_association_wrangler, + target_discrs_and_qbx_sides, + target_association_tolerance=( + self.target_association_tolerance)) + + tree = self.tree() - with cl.CommandQueue(self.cl_context) as queue: result = cl.array.empty(queue, tgt_info.ntargets, tree.particle_id_dtype) result[:self.ncenters].fill(target_state.NO_QBX_NEEDED) result[self.ncenters:] = tgt_assoc_result.target_to_center - return result + return result.with_queue(None) @memoize_method def center_to_tree_targets(self): diff --git a/pytential/qbx/refinement.py b/pytential/qbx/refinement.py index dc8a784408c3422a85974d15a384829aae040559..2ab598d85ba17e38c2533e9486bc597756a89156 100644 --- a/pytential/qbx/refinement.py +++ b/pytential/qbx/refinement.py @@ -34,7 +34,8 @@ import pyopencl as cl from pytools import memoize_method from boxtree.area_query import AreaQueryElementwiseTemplate from boxtree.tools import InlineBinarySearch -from pytential.qbx.utils import QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS +from pytential.qbx.utils import ( + QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS, TreeWranglerBase) import logging logger = logging.getLogger(__name__) @@ -264,7 +265,7 @@ class RefinerCodeContainer(object): # {{{ wrangler -class RefinerWrangler(object): +class RefinerWrangler(TreeWranglerBase): def __init__(self, code_container, queue): self.code_container = code_container @@ -420,18 +421,6 @@ class RefinerWrangler(object): # }}} - def build_tree(self, lpot_source, use_base_fine_discr=False): - tb = self.code_container.tree_builder() - from pytential.qbx.utils import build_tree_with_qbx_metadata - return build_tree_with_qbx_metadata( - self.queue, tb, lpot_source, use_base_fine_discr=use_base_fine_discr) - - def find_peer_lists(self, tree): - plf = self.code_container.peer_list_finder() - peer_lists, evt = plf(self.queue, tree) - cl.wait_for_events([evt]) - return peer_lists - def refine(self, density_discr, refiner, refine_flags, factory, debug): """ Refine the underlying mesh and discretization. @@ -476,7 +465,7 @@ def make_empty_refine_flags(queue, lpot_source, use_base_fine_discr=False): # {{{ main entry point -def refine_for_global_qbx(lpot_source, code_container, +def refine_for_global_qbx(lpot_source, wrangler, group_factory, kernel_length_scale=None, # FIXME: Set debug=False once everything works. refine_flags=None, debug=True, maxiter=50): @@ -485,7 +474,7 @@ def refine_for_global_qbx(lpot_source, code_container, :arg lpot_source: An instance of :class:`QBXLayerPotentialSource`. - :arg code_container: An instance of :class:`RefinerCodeContainer`. + :arg wrangler: An instance of :class:`RefinerWrangler`. :arg group_factory: An instance of :class:`meshmode.mesh.discretization.ElementGroupFactory`. Used for @@ -514,111 +503,108 @@ def refine_for_global_qbx(lpot_source, code_container, from meshmode.discretization.connection import ( ChainedDiscretizationConnection, make_same_mesh_connection) - with cl.CommandQueue(lpot_source.cl_context) as queue: - wrangler = code_container.get_wrangler(queue) - - refiner = Refiner(lpot_source.density_discr.mesh) - connections = [] - - # Do initial refinement. - if refine_flags is not None: + refiner = Refiner(lpot_source.density_discr.mesh) + connections = [] + + # Do initial refinement. + if refine_flags is not None: + conn = wrangler.refine( + lpot_source.density_discr, refiner, refine_flags, group_factory, + debug) + connections.append(conn) + lpot_source = lpot_source.copy(density_discr=conn.to_discr) + + # {{{ first stage refinement + + must_refine = True + niter = 0 + + while must_refine: + must_refine = False + niter += 1 + + if niter > maxiter: + from warnings import warn + warn( + "Max iteration count reached in QBX layer potential source" + " refiner.", + RefinerNotConvergedWarning) + break + + # Build tree and auxiliary data. + # FIXME: The tree should not have to be rebuilt at each iteration. + tree = wrangler.build_tree(lpot_source) + peer_lists = wrangler.find_peer_lists(tree) + refine_flags = make_empty_refine_flags(wrangler.queue, lpot_source) + + # Check condition 1. + must_refine |= wrangler.check_expansion_disks_undisturbed_by_sources( + lpot_source, tree, peer_lists, refine_flags, debug) + + # Check condition 3. + if kernel_length_scale is not None: + must_refine |= ( + wrangler.check_kernel_length_scale_to_panel_size_ratio( + lpot_source, kernel_length_scale, refine_flags, debug)) + + if must_refine: conn = wrangler.refine( - lpot_source.density_discr, refiner, refine_flags, group_factory, - debug) + lpot_source.density_discr, refiner, refine_flags, + group_factory, debug) connections.append(conn) lpot_source = lpot_source.copy(density_discr=conn.to_discr) - # {{{ first stage refinement - - must_refine = True - niter = 0 - - while must_refine: - must_refine = False - niter += 1 - - if niter > maxiter: - from warnings import warn - warn( - "Max iteration count reached in QBX layer potential source" - " refiner.", - RefinerNotConvergedWarning) - break - - # Build tree and auxiliary data. - # FIXME: The tree should not have to be rebuilt at each iteration. - tree = wrangler.build_tree(lpot_source) - peer_lists = wrangler.find_peer_lists(tree) - refine_flags = make_empty_refine_flags(queue, lpot_source) - - # Check condition 1. - must_refine |= wrangler.check_expansion_disks_undisturbed_by_sources( - lpot_source, tree, peer_lists, refine_flags, debug) - - # Check condition 3. - if kernel_length_scale is not None: - must_refine |= ( - wrangler.check_kernel_length_scale_to_panel_size_ratio( - lpot_source, kernel_length_scale, refine_flags, debug)) - - if must_refine: - conn = wrangler.refine( - lpot_source.density_discr, refiner, refine_flags, - group_factory, debug) - connections.append(conn) - lpot_source = lpot_source.copy(density_discr=conn.to_discr) - - del tree - del refine_flags - del peer_lists - - # }}} - - # {{{ second stage refinement - - must_refine = True - niter = 0 - fine_connections = [] - - base_fine_density_discr = lpot_source.density_discr - - while must_refine: - must_refine = False - niter += 1 - - if niter > maxiter: - from warnings import warn - warn( - "Max iteration count reached in QBX layer potential source" - " refiner.", - RefinerNotConvergedWarning) - break - - # Build tree and auxiliary data. - # FIXME: The tree should not have to be rebuilt at each iteration. - tree = wrangler.build_tree(lpot_source, use_base_fine_discr=True) - peer_lists = wrangler.find_peer_lists(tree) - refine_flags = make_empty_refine_flags( - queue, lpot_source, use_base_fine_discr=True) - - must_refine |= wrangler.check_sufficient_source_quadrature_resolution( - lpot_source, tree, peer_lists, refine_flags, debug) - - if must_refine: - conn = wrangler.refine( - base_fine_density_discr, - refiner, refine_flags, group_factory, debug) - base_fine_density_discr = conn.to_discr - fine_connections.append(conn) - lpot_source = lpot_source.copy( - base_resampler=ChainedDiscretizationConnection( - fine_connections)) - - del tree - del refine_flags - del peer_lists - - # }}} + del tree + del refine_flags + del peer_lists + + # }}} + + # {{{ second stage refinement + + must_refine = True + niter = 0 + fine_connections = [] + + base_fine_density_discr = lpot_source.density_discr + + while must_refine: + must_refine = False + niter += 1 + + if niter > maxiter: + from warnings import warn + warn( + "Max iteration count reached in QBX layer potential source" + " refiner.", + RefinerNotConvergedWarning) + break + + # Build tree and auxiliary data. + # FIXME: The tree should not have to be rebuilt at each iteration. + tree = wrangler.build_tree(lpot_source, use_base_fine_discr=True) + peer_lists = wrangler.find_peer_lists(tree) + refine_flags = make_empty_refine_flags( + wrangler.queue, lpot_source, use_base_fine_discr=True) + + must_refine |= wrangler.check_sufficient_source_quadrature_resolution( + lpot_source, tree, peer_lists, refine_flags, debug) + + if must_refine: + conn = wrangler.refine( + base_fine_density_discr, + refiner, refine_flags, group_factory, debug) + base_fine_density_discr = conn.to_discr + fine_connections.append(conn) + lpot_source = lpot_source.copy( + base_resampler=ChainedDiscretizationConnection( + fine_connections)) + + del tree + del refine_flags + del peer_lists + + # }}} lpot_source = lpot_source.copy(debug=debug, _refined_for_global_qbx=True) diff --git a/pytential/qbx/target_assoc.py b/pytential/qbx/target_assoc.py index 5e1bcb1eb88fb7ed97ace3175c451beb8631dac8..946fe96b9ca7ad73f7685021ca2f128ea8bd5359 100644 --- a/pytential/qbx/target_assoc.py +++ b/pytential/qbx/target_assoc.py @@ -37,7 +37,7 @@ from boxtree.area_query import AreaQueryElementwiseTemplate from boxtree.tools import InlineBinarySearch from cgen import Enum from pytential.qbx.utils import ( - QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS) + QBX_TREE_C_PREAMBLE, QBX_TREE_MAKO_DEFS, TreeWranglerBase) unwrap_args = AreaQueryElementwiseTemplate.unwrap_args @@ -324,89 +324,96 @@ class QBXTargetAssociation(DeviceDataRecord): pass -class QBXTargetAssociator(object): +class TargetAssociationCodeContainer(object): def __init__(self, cl_context): - from boxtree.tree_build import TreeBuilder - self.tree_builder = TreeBuilder(cl_context) self.cl_context = cl_context - from boxtree.area_query import PeerListFinder, SpaceInvaderQueryBuilder - self.peer_list_finder = PeerListFinder(cl_context) - self.space_invader_query = SpaceInvaderQueryBuilder(cl_context) - - # {{{ kernel generation @memoize_method - def get_qbx_target_marker(self, - dimensions, - coord_dtype, - box_id_dtype, - peer_list_idx_dtype, - particle_id_dtype, - max_levels): + def target_marker(self, dimensions, coord_dtype, box_id_dtype, + peer_list_idx_dtype, particle_id_dtype, max_levels): return QBX_TARGET_MARKER.generate( - self.cl_context, - dimensions, - coord_dtype, - box_id_dtype, - peer_list_idx_dtype, - max_levels, - extra_type_aliases=(("particle_id_t", particle_id_dtype),)) + self.cl_context, + dimensions, + coord_dtype, + box_id_dtype, + peer_list_idx_dtype, + max_levels, + extra_type_aliases=(("particle_id_t", particle_id_dtype),)) @memoize_method - def get_qbx_center_finder(self, - dimensions, - coord_dtype, - box_id_dtype, - peer_list_idx_dtype, - particle_id_dtype, - max_levels): + def center_finder(self, dimensions, coord_dtype, box_id_dtype, + peer_list_idx_dtype, particle_id_dtype, max_levels): return QBX_CENTER_FINDER.generate( - self.cl_context, - dimensions, - coord_dtype, - box_id_dtype, - peer_list_idx_dtype, - max_levels, - extra_type_aliases=(("particle_id_t", particle_id_dtype),)) + self.cl_context, + dimensions, + coord_dtype, + box_id_dtype, + peer_list_idx_dtype, + max_levels, + extra_type_aliases=(("particle_id_t", particle_id_dtype),)) @memoize_method - def get_qbx_failed_target_association_refiner(self, dimensions, coord_dtype, - box_id_dtype, peer_list_idx_dtype, - particle_id_dtype, max_levels): + def refiner_for_failed_target_association(self, dimensions, coord_dtype, + box_id_dtype, peer_list_idx_dtype, particle_id_dtype, max_levels): return QBX_FAILED_TARGET_ASSOCIATION_REFINER.generate( - self.cl_context, - dimensions, - coord_dtype, - box_id_dtype, - peer_list_idx_dtype, - max_levels, - extra_type_aliases=(("particle_id_t", particle_id_dtype),)) + self.cl_context, + dimensions, + coord_dtype, + box_id_dtype, + peer_list_idx_dtype, + max_levels, + extra_type_aliases=(("particle_id_t", particle_id_dtype),)) + + @memoize_method + def peer_list_finder(self): + from boxtree.area_query import PeerListFinder + return PeerListFinder(self.cl_context) + + @memoize_method + def space_invader_query(self): + from boxtree.area_query import SpaceInvaderQueryBuilder + return SpaceInvaderQueryBuilder(self.cl_context) + + @memoize_method + def tree_builder(self): + from boxtree.tree_build import TreeBuilder + return TreeBuilder(self.cl_context) - # }}} + def get_wrangler(self, queue): + return TargetAssociationWrangler(self, queue) - def mark_targets(self, queue, tree, peer_lists, lpot_source, target_status, + +class TargetAssociationWrangler(TreeWranglerBase): + + def __init__(self, code_container, queue): + self.code_container = code_container + self.queue = queue + + def mark_targets(self, tree, peer_lists, lpot_source, target_status, debug, wait_for=None): # Round up level count--this gets included in the kernel as # a stack bound. Rounding avoids too many kernel versions. from pytools import div_ceil max_levels = 10 * div_ceil(tree.nlevels, 10) - knl = self.get_qbx_target_marker( + knl = self.code_container.target_marker( tree.dimensions, tree.coord_dtype, tree.box_id_dtype, peer_lists.peer_list_starts.dtype, tree.particle_id_dtype, max_levels) - found_target_close_to_panel = cl.array.zeros(queue, 1, np.int32) + found_target_close_to_panel = cl.array.zeros(self.queue, 1, np.int32) found_target_close_to_panel.finish() # Perform a space invader query over the sources. source_slice = tree.sorted_target_ids[tree.qbx_user_source_slice] - sources = [axis.with_queue(queue)[source_slice] for axis in tree.sources] - tunnel_radius_by_source = \ - lpot_source._close_target_tunnel_radius("nsources").with_queue(queue) + sources = [ + axis.with_queue(self.queue)[source_slice] for axis in tree.sources] + tunnel_radius_by_source = ( + lpot_source._close_target_tunnel_radius("nsources") + .with_queue(self.queue)) # Target-marking algorithm (TGTMARK): # @@ -433,8 +440,8 @@ class QBXTargetAssociator(object): # sources are fixed (which sort of makes sense, given that the number # of targets per box is not bounded). - box_to_search_dist, evt = self.space_invader_query( - queue, + box_to_search_dist, evt = self.code_container.space_invader_query()( + self.queue, tree, sources, tunnel_radius_by_source, @@ -460,7 +467,7 @@ class QBXTargetAssociator(object): found_target_close_to_panel, *tree.sources), range=slice(tree.nqbxtargets), - queue=queue, + queue=self.queue, wait_for=wait_for) if debug: @@ -476,7 +483,7 @@ class QBXTargetAssociator(object): return (found_target_close_to_panel == 1).all().get() - def try_find_centers(self, queue, tree, peer_lists, lpot_source, + def try_find_centers(self, tree, peer_lists, lpot_source, target_status, target_flags, target_assoc, target_association_tolerance, debug, wait_for=None): # Round up level count--this gets included in the kernel as @@ -484,7 +491,7 @@ class QBXTargetAssociator(object): from pytools import div_ceil max_levels = 10 * div_ceil(tree.nlevels, 10) - knl = self.get_qbx_center_finder( + knl = self.code_container.center_finder( tree.dimensions, tree.coord_dtype, tree.box_id_dtype, peer_lists.peer_list_starts.dtype, @@ -496,11 +503,13 @@ class QBXTargetAssociator(object): marked_target_count = int(cl.array.sum(target_status).get()) # Perform a space invader query over the centers. - center_slice = \ - tree.sorted_target_ids[tree.qbx_user_center_slice].with_queue(queue) - centers = [axis.with_queue(queue)[center_slice] for axis in tree.sources] + center_slice = ( + tree.sorted_target_ids[tree.qbx_user_center_slice] + .with_queue(self.queue)) + centers = [ + axis.with_queue(self.queue)[center_slice] for axis in tree.sources] expansion_radii_by_center = \ - lpot_source._expansion_radii("ncenters").with_queue(queue) + lpot_source._expansion_radii("ncenters").with_queue(self.queue) expansion_radii_by_center_with_tolerance = \ expansion_radii_by_center * (1 + target_association_tolerance) @@ -510,8 +519,8 @@ class QBXTargetAssociator(object): # (2) Area query from targets with those radii to find closest eligible # center. - box_to_search_dist, evt = self.space_invader_query( - queue, + box_to_search_dist, evt = self.code_container.space_invader_query()( + self.queue, tree, centers, expansion_radii_by_center_with_tolerance, @@ -520,7 +529,7 @@ class QBXTargetAssociator(object): wait_for = [evt] min_dist_to_center = cl.array.empty( - queue, tree.nqbxtargets, tree.coord_dtype) + self.queue, tree.nqbxtargets, tree.coord_dtype) min_dist_to_center.fill(np.inf) wait_for.extend(min_dist_to_center.events) @@ -543,7 +552,7 @@ class QBXTargetAssociator(object): min_dist_to_center, *tree.sources), range=slice(tree.nqbxtargets), - queue=queue, + queue=self.queue, wait_for=wait_for) if debug: @@ -559,7 +568,7 @@ class QBXTargetAssociator(object): logger.info("target association: done finding centers for targets") return - def mark_panels_for_refinement(self, queue, tree, peer_lists, lpot_source, + def mark_panels_for_refinement(self, tree, peer_lists, lpot_source, target_status, refine_flags, debug, wait_for=None): # Round up level count--this gets included in the kernel as @@ -567,26 +576,28 @@ class QBXTargetAssociator(object): from pytools import div_ceil max_levels = 10 * div_ceil(tree.nlevels, 10) - knl = self.get_qbx_failed_target_association_refiner( + knl = self.code_container.refiner_for_failed_target_association( tree.dimensions, tree.coord_dtype, tree.box_id_dtype, peer_lists.peer_list_starts.dtype, tree.particle_id_dtype, max_levels) - found_panel_to_refine = cl.array.zeros(queue, 1, np.int32) + found_panel_to_refine = cl.array.zeros(self.queue, 1, np.int32) found_panel_to_refine.finish() # Perform a space invader query over the sources. source_slice = tree.user_source_ids[tree.qbx_user_source_slice] - sources = [axis.with_queue(queue)[source_slice] for axis in tree.sources] - tunnel_radius_by_source = \ - lpot_source._close_target_tunnel_radius("nsources").with_queue(queue) + sources = [ + axis.with_queue(self.queue)[source_slice] for axis in tree.sources] + tunnel_radius_by_source = ( + lpot_source._close_target_tunnel_radius("nsources") + .with_queue(self.queue)) # See (TGTMARK) above for algorithm. - box_to_search_dist, evt = self.space_invader_query( - queue, + box_to_search_dist, evt = self.code_container.space_invader_query()( + self.queue, tree, sources, tunnel_radius_by_source, @@ -613,7 +624,7 @@ class QBXTargetAssociator(object): found_panel_to_refine, *tree.sources), range=slice(tree.nqbxtargets), - queue=queue, + queue=self.queue, wait_for=wait_for) if debug: @@ -629,9 +640,9 @@ class QBXTargetAssociator(object): return (found_panel_to_refine == 1).all().get() - def make_target_flags(self, queue, target_discrs_and_qbx_sides): + def make_target_flags(self, target_discrs_and_qbx_sides): ntargets = sum(discr.nnodes for discr, _ in target_discrs_and_qbx_sides) - target_flags = cl.array.empty(queue, ntargets, dtype=np.int32) + target_flags = cl.array.empty(self.queue, ntargets, dtype=np.int32) offset = 0 for discr, flags in target_discrs_and_qbx_sides: @@ -645,118 +656,112 @@ class QBXTargetAssociator(object): target_flags.finish() return target_flags - def make_default_target_association(self, queue, ntargets): - target_to_center = cl.array.empty(queue, ntargets, dtype=np.int32) + def make_default_target_association(self, ntargets): + target_to_center = cl.array.empty(self.queue, ntargets, dtype=np.int32) target_to_center.fill(-1) target_to_center.finish() return QBXTargetAssociation(target_to_center=target_to_center) - def __call__(self, lpot_source, target_discrs_and_qbx_sides, - target_association_tolerance, debug=True, wait_for=None): - """ - Entry point for calling the target associator. - :arg lpot_source: An instance of :class:`NewQBXLayerPotentialSource` +def associate_targets_to_qbx_centers(lpot_source, wrangler, + target_discrs_and_qbx_sides, target_association_tolerance, + debug=True, wait_for=None): + """ + Entry point for calling the target associator. + + :arg lpot_source: An instance of :class:`QBXLayerPotentialSource` + + :arg wrangler: An instance of :class:`TargetAssociationWrangler` - :arg target_discrs_and_qbx_sides: + :arg target_discrs_and_qbx_sides: - a list of tuples ``(discr, sides)``, where - *discr* is a - :class:`pytential.discretization.Discretization` - or a - :class:`pytential.discretization.target.TargetBase` instance, and - *sides* is either a :class:`int` or - an array of (:class:`numpy.int8`) side requests for each - target. + a list of tuples ``(discr, sides)``, where + *discr* is a + :class:`pytential.discretization.Discretization` + or a + :class:`pytential.discretization.target.TargetBase` instance, and + *sides* is either a :class:`int` or + an array of (:class:`numpy.int8`) side requests for each + target. - The side request can take the following values for each target: + The side request can take the following values for each target: - ===== ============================================== - Value Meaning - ===== ============================================== - 0 Volume target. If near a QBX center, - the value from the QBX expansion is returned, - otherwise the volume potential is returned. + ===== ============================================== + Value Meaning + ===== ============================================== + 0 Volume target. If near a QBX center, + the value from the QBX expansion is returned, + otherwise the volume potential is returned. - -1 Surface target. Return interior limit from - interior-side QBX expansion. + -1 Surface target. Return interior limit from + interior-side QBX expansion. - +1 Surface target. Return exterior limit from - exterior-side QBX expansion. + +1 Surface target. Return exterior limit from + exterior-side QBX expansion. - -2 Volume target. If within an *interior* QBX disk, - the value from the QBX expansion is returned, - otherwise the volume potential is returned. + -2 Volume target. If within an *interior* QBX disk, + the value from the QBX expansion is returned, + otherwise the volume potential is returned. - +2 Volume target. If within an *exterior* QBX disk, - the value from the QBX expansion is returned, - otherwise the volume potential is returned. - ===== ============================================== + +2 Volume target. If within an *exterior* QBX disk, + the value from the QBX expansion is returned, + otherwise the volume potential is returned. + ===== ============================================== - :raises QBXTargetAssociationFailedException: - when target association failed to find a center for a target. - The returned exception object contains suggested refine flags. + :raises QBXTargetAssociationFailedException: + when target association failed to find a center for a target. + The returned exception object contains suggested refine flags. - :returns: - """ + :returns: A :class:`QBXTargetAssociation`. + """ - with cl.CommandQueue(self.cl_context) as queue: - from pytential.qbx.utils import build_tree_with_qbx_metadata + tree = wrangler.build_tree(lpot_source, + [discr for discr, _ in target_discrs_and_qbx_sides]) - tree = build_tree_with_qbx_metadata( - queue, - self.tree_builder, - lpot_source, - [discr for discr, _ in target_discrs_and_qbx_sides]) + peer_lists = wrangler.find_peer_lists(tree) - peer_lists, evt = self.peer_list_finder(queue, tree, wait_for) - wait_for = [evt] + target_status = cl.array.zeros(wrangler.queue, tree.nqbxtargets, dtype=np.int32) + target_status.finish() - target_status = cl.array.zeros(queue, tree.nqbxtargets, dtype=np.int32) - target_status.finish() + have_close_targets = wrangler.mark_targets(tree, peer_lists, + lpot_source, target_status, debug) - have_close_targets = self.mark_targets(queue, tree, peer_lists, - lpot_source, target_status, - debug) + target_assoc = wrangler.make_default_target_association(tree.nqbxtargets) - target_assoc = self.make_default_target_association( - queue, tree.nqbxtargets) + if not have_close_targets: + return target_assoc.with_queue(None) - if not have_close_targets: - return target_assoc.with_queue(None) + target_flags = wrangler.make_target_flags(target_discrs_and_qbx_sides) - target_flags = self.make_target_flags(queue, target_discrs_and_qbx_sides) + wrangler.try_find_centers(tree, peer_lists, lpot_source, target_status, + target_flags, target_assoc, target_association_tolerance, debug) - self.try_find_centers(queue, tree, peer_lists, lpot_source, - target_status, target_flags, target_assoc, - target_association_tolerance, debug) + center_not_found = ( + target_status == target_status_enum.MARKED_QBX_CENTER_PENDING) - center_not_found = ( - target_status == target_status_enum.MARKED_QBX_CENTER_PENDING) + if center_not_found.any().get(): + surface_target = ( + (target_flags == target_flag_enum.INTERIOR_SURFACE_TARGET) + | (target_flags == target_flag_enum.EXTERIOR_SURFACE_TARGET)) - if center_not_found.any().get(): - surface_target = ( - (target_flags == target_flag_enum.INTERIOR_SURFACE_TARGET) - | (target_flags == target_flag_enum.EXTERIOR_SURFACE_TARGET)) + if (center_not_found & surface_target).any().get(): + logger.warning("An on-surface target was not " + "assigned a center. As a remedy you can try increasing " + "the \"target_association_tolerance\" parameter, but " + "this could also cause an invalid center assignment.") - if (center_not_found & surface_target).any().get(): - logger.warning("An on-surface target was not " - "assigned a center. As a remedy you can try increasing " - "the \"target_association_tolerance\" parameter, but " - "this could also cause an invalid center assignment.") + refine_flags = cl.array.zeros( + wrangler.queue, tree.nqbxpanels, dtype=np.int32) + have_panel_to_refine = wrangler.mark_panels_for_refinement( + tree, peer_lists, lpot_source, target_status, refine_flags, debug) - refine_flags = cl.array.zeros(queue, tree.nqbxpanels, dtype=np.int32) - have_panel_to_refine = self.mark_panels_for_refinement(queue, - tree, peer_lists, - lpot_source, target_status, - refine_flags, debug) - assert have_panel_to_refine - raise QBXTargetAssociationFailedException( - refine_flags=refine_flags.with_queue(None), - failed_target_flags=center_not_found.with_queue(None)) + assert have_panel_to_refine + raise QBXTargetAssociationFailedException( + refine_flags=refine_flags.with_queue(None), + failed_target_flags=center_not_found.with_queue(None)) - return target_assoc.with_queue(None) + return target_assoc.with_queue(None) # }}} diff --git a/pytential/qbx/utils.py b/pytential/qbx/utils.py index 02a19e007eae86e90066224f97b0b261a02a7e6a..edf92dfe1af8a1af70042b761999eed0ed8a6eb2 100644 --- a/pytential/qbx/utils.py +++ b/pytential/qbx/utils.py @@ -137,6 +137,27 @@ def get_interleaved_radii(queue, lpot_source): # }}} +# {{{ peer list wrangler mixin + +class TreeWranglerBase(object): + + def build_tree(self, lpot_source, targets_list=(), + use_base_fine_discr=False): + tb = self.code_container.tree_builder() + from pytential.qbx.utils import build_tree_with_qbx_metadata + return build_tree_with_qbx_metadata( + self.queue, tb, lpot_source, targets_list=targets_list, + use_base_fine_discr=use_base_fine_discr) + + def find_peer_lists(self, tree): + plf = self.code_container.peer_list_finder() + peer_lists, evt = plf(self.queue, tree) + cl.wait_for_events([evt]) + return peer_lists + +# }}} + + # {{{ panel sizes def panel_sizes(discr, last_dim_length): diff --git a/test/test_global_qbx.py b/test/test_global_qbx.py index 6065d183854cd49f5b991e9d5777f5c5fc30052f..12c816f4a15d771a1ff94502e4bb516ba5d2999c 100644 --- a/test/test_global_qbx.py +++ b/test/test_global_qbx.py @@ -102,7 +102,7 @@ def run_source_refinement_test(ctx_getter, mesh, order, helmholtz_k=None): refiner_extra_kwargs["kernel_length_scale"] = 5/helmholtz_k lpot_source, conn = refine_for_global_qbx( - lpot_source, RefinerCodeContainer(cl_ctx), + lpot_source, RefinerCodeContainer(cl_ctx).get_wrangler(queue), factory, **refiner_extra_kwargs) from pytential.qbx.utils import get_centers_on_side @@ -287,9 +287,15 @@ def test_target_association(ctx_getter, curve_name, curve_f, nelements): # {{{ run target associator and check - from pytential.qbx.target_assoc import QBXTargetAssociator - target_assoc = ( - QBXTargetAssociator(cl_ctx)(lpot_source, target_discrs, + from pytential.qbx.target_assoc import ( + TargetAssociationCodeContainer, associate_targets_to_qbx_centers) + + code_container = TargetAssociationCodeContainer(cl_ctx) + + target_assoc = (associate_targets_to_qbx_centers( + lpot_source, + code_container.get_wrangler(queue), + target_discrs, target_association_tolerance=1e-10) .get(queue=queue)) @@ -388,11 +394,17 @@ def test_target_association_failure(ctx_getter): ) from pytential.qbx.target_assoc import ( - QBXTargetAssociator, QBXTargetAssociationFailedException) + TargetAssociationCodeContainer, associate_targets_to_qbx_centers, + QBXTargetAssociationFailedException) + + code_container = TargetAssociationCodeContainer(cl_ctx) with pytest.raises(QBXTargetAssociationFailedException): - QBXTargetAssociator(cl_ctx)(lpot_source, targets, - target_association_tolerance=1e-10) + associate_targets_to_qbx_centers( + lpot_source, + code_container.get_wrangler(queue), + targets, + target_association_tolerance=1e-10) # }}}