diff --git a/pytential/qbx/target_assoc.py b/pytential/qbx/target_assoc.py index f30dda41460292832697f57de4c563611f165004..cbf2a26c62c4936cf729b50ee96c088212d1f38f 100644 --- a/pytential/qbx/target_assoc.py +++ b/pytential/qbx/target_assoc.py @@ -194,7 +194,7 @@ QBX_CENTER_FINDER = AreaQueryElementwiseTemplate( coord_t *particles_${ax}, %endfor """, - ball_center_and_radius_expr=QBX_TREE_C_PREAMBLE + QBX_TREE_MAKO_DEFS + r""" + ball_center_and_radius_expr=QBX_TREE_C_PREAMBLE + QBX_TREE_MAKO_DEFS + r"""//CL// coord_vec_t tgt_coords; ${load_particle("INDEX_FOR_TARGET_PARTICLE(i)", "tgt_coords")} { @@ -203,31 +203,39 @@ QBX_CENTER_FINDER = AreaQueryElementwiseTemplate( ${ball_radius} = box_to_search_dist[my_box]; } """, - leaf_found_op=QBX_TREE_MAKO_DEFS + r""" - for (particle_id_t center_idx = box_to_center_starts[${leaf_box_id}]; - center_idx < box_to_center_starts[${leaf_box_id} + 1]; - ++center_idx) + leaf_found_op=QBX_TREE_MAKO_DEFS + r"""//CL// + if (target_status[i] == MARKED_QBX_CENTER_PENDING + // Found one in a prior leaf, but there may well be another + // that's closer. + || target_status[i] == MARKED_QBX_CENTER_FOUND) { - particle_id_t center = box_to_center_lists[center_idx]; - int center_side = SIDE_FOR_CENTER_PARTICLE(center); - - // Sign of side should match requested target sign. - if (center_side * target_flags[i] < 0) - { - continue; - } - - coord_vec_t center_coords; - ${load_particle("INDEX_FOR_CENTER_PARTICLE(center)", "center_coords")} - coord_t my_dist_to_center = distance(tgt_coords, center_coords); - - if (my_dist_to_center - <= expansion_radii_by_center_with_stick_out[center] - && my_dist_to_center < min_dist_to_center[i]) + for (particle_id_t center_idx = box_to_center_starts[${leaf_box_id}]; + center_idx < box_to_center_starts[${leaf_box_id} + 1]; + ++center_idx) { - target_status[i] = MARKED_QBX_CENTER_FOUND; - min_dist_to_center[i] = my_dist_to_center; - target_to_center[i] = center; + particle_id_t center = box_to_center_lists[center_idx]; + + int center_side = SIDE_FOR_CENTER_PARTICLE(center); + + // Sign of side should match requested target sign. + if (center_side * target_flags[i] < 0) + { + continue; + } + + coord_vec_t center_coords; + ${load_particle( + "INDEX_FOR_CENTER_PARTICLE(center)", "center_coords")} + coord_t my_dist_to_center = distance(tgt_coords, center_coords); + + if (my_dist_to_center + <= expansion_radii_by_center_with_stick_out[center] + && my_dist_to_center < min_dist_to_center[i]) + { + target_status[i] = MARKED_QBX_CENTER_FOUND; + min_dist_to_center[i] = my_dist_to_center; + target_to_center[i] = center; + } } } """, @@ -381,7 +389,8 @@ class QBXTargetAssociator(object): def mark_targets(self, queue, tree, peer_lists, lpot_source, target_status, debug, wait_for=None): - # Avoid generating too many kernels. + # Round up level count--this gets included in the kernel as + # a stack bound. Rounding avoids too many kernel versions. from pytools import div_ceil max_levels = 10 * div_ceil(tree.nlevels, 10) @@ -396,7 +405,7 @@ class QBXTargetAssociator(object): found_target_close_to_panel.finish() # Perform a space invader query over the sources. - source_slice = tree.user_source_ids[tree.qbx_user_source_slice] + source_slice = tree.sorted_target_ids[tree.qbx_user_source_slice] sources = [axis.with_queue(queue)[source_slice] for axis in tree.sources] tunnel_radius_by_source = \ lpot_source._close_target_tunnel_radius("nsources").with_queue(queue) @@ -472,7 +481,8 @@ class QBXTargetAssociator(object): def try_find_centers(self, queue, tree, peer_lists, lpot_source, target_status, target_flags, target_assoc, stick_out_factor, debug, wait_for=None): - # Avoid generating too many kernels. + # Round up level count--this gets included in the kernel as + # a stack bound. Rounding avoids too many kernel versions. from pytools import div_ceil max_levels = 10 * div_ceil(tree.nlevels, 10) @@ -554,7 +564,8 @@ class QBXTargetAssociator(object): def mark_panels_for_refinement(self, queue, tree, peer_lists, lpot_source, target_status, refine_flags, debug, wait_for=None): - # Avoid generating too many kernels. + # Round up level count--this gets included in the kernel as + # a stack bound. Rounding avoids too many kernel versions. from pytools import div_ceil max_levels = 10 * div_ceil(tree.nlevels, 10) diff --git a/test/test_global_qbx.py b/test/test_global_qbx.py index 7326efc13d212be30c83adecedc7a3e579722605..c5f77476b908ab7e5681784e6de73fa647ef428c 100644 --- a/test/test_global_qbx.py +++ b/test/test_global_qbx.py @@ -234,10 +234,9 @@ def test_target_association(ctx_getter, curve_name, curve_f, nelements): lpot_source, conn = QBXLayerPotentialSource(discr, order).with_refinement() del discr - from pytential.qbx.utils import get_centers_on_side - - int_centers = get_centers_on_side(lpot_source, -1) - ext_centers = get_centers_on_side(lpot_source, +1) + from pytential.qbx.utils import get_interleaved_centers + centers = np.array([ax.get(queue) + for ax in get_interleaved_centers(queue, lpot_source)]) # }}} @@ -295,59 +294,61 @@ def test_target_association(ctx_getter, curve_name, curve_f, nelements): QBXTargetAssociator(cl_ctx)(lpot_source, target_discrs) .get(queue=queue)) - expansion_radii = lpot_source._expansion_radii("nsources").get(queue) + expansion_radii = lpot_source._expansion_radii("ncenters").get(queue) - int_centers = np.array([axis.get(queue) for axis in int_centers]) - ext_centers = np.array([axis.get(queue) for axis in ext_centers]) int_targets = np.array([axis.get(queue) for axis in int_targets.nodes()]) ext_targets = np.array([axis.get(queue) for axis in ext_targets.nodes()]) # Checks that the sources match with their own centers. - def check_on_surface_targets(nsources, true_side, target_to_source_result, + def check_on_surface_targets(nsources, true_side, target_to_center, target_to_side_result): + assert (target_to_center >= 0).all() + sources = np.arange(0, nsources) - assert (target_to_source_result == sources).all() + + # Centers are on alternating sides of the geometry. Dividing by + # two yields the number of the source that spawned the center. + assert (target_to_center//2 == sources).all() + assert (target_to_side_result == true_side).all() # Checks that the targets match with centers on the appropriate side and # within the allowable distance. def check_close_targets(centers, targets, true_side, - target_to_source_result, target_to_side_result): + target_to_center, target_to_side_result, + tgt_slice): + assert (target_to_center >= 0).all() assert (target_to_side_result == true_side).all() - dists = la.norm((targets.T - centers.T[target_to_source_result]), axis=1) - assert (dists <= expansion_radii[target_to_source_result]).all() + dists = la.norm((targets.T - centers.T[target_to_center]), axis=1) + assert (dists <= expansion_radii[target_to_center]).all() - # Checks that far targets are not assigned a center. - def check_far_targets(target_to_source_result): - assert (target_to_source_result == -1).all() - - # Centers for source i are located at indices 2 * i, 2 * i + 1 - target_to_source = target_assoc.target_to_center // 2 # Center side order = -1, 1, -1, 1, ... target_to_center_side = 2 * (target_assoc.target_to_center % 2) - 1 check_on_surface_targets( nsources, -1, - target_to_source[surf_int_slice], + target_assoc.target_to_center[surf_int_slice], target_to_center_side[surf_int_slice]) check_on_surface_targets( nsources, +1, - target_to_source[surf_ext_slice], + target_assoc.target_to_center[surf_ext_slice], target_to_center_side[surf_ext_slice]) check_close_targets( - int_centers, int_targets, -1, - target_to_source[vol_int_slice], - target_to_center_side[vol_int_slice]) + centers, int_targets, -1, + target_assoc.target_to_center[vol_int_slice], + target_to_center_side[vol_int_slice], + vol_int_slice) check_close_targets( - ext_centers, ext_targets, +1, - target_to_source[vol_ext_slice], - target_to_center_side[vol_ext_slice]) + centers, ext_targets, +1, + target_assoc.target_to_center[vol_ext_slice], + target_to_center_side[vol_ext_slice], + vol_ext_slice) - check_far_targets( - target_to_source[far_slice]) + # Checks that far targets are not assigned a center. + assert (target_assoc.target_to_center[far_slice] == -1).all() # }}}