diff --git a/boxtree/area_query.py b/boxtree/area_query.py index 66ace90e028f4bc7c038425af1c9d397d0cd83fd..532520744a1eb0b61e6f0fb4dd77630734322cea 100644 --- a/boxtree/area_query.py +++ b/boxtree/area_query.py @@ -393,7 +393,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t box_id) // child_box_id lives on walk_level+1. bool a_or_o = is_adjacent_or_overlapping(root_extent, - center, level, child_center, walk_level+1, false); + center, level, child_center, walk_level+1); if (a_or_o) { @@ -422,8 +422,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t box_id) { ${load_center("next_child_center", "next_child_id")} must_be_peer &= !is_adjacent_or_overlapping(root_extent, - center, level, next_child_center, walk_level+2, - false); + center, level, next_child_center, walk_level+2); } } @@ -548,8 +547,6 @@ class AreaQueryElementwiseTemplate(object): ("peer_list_idx_dtype", peer_list_idx_dtype), ("debug", False), ("root_extent_stretch_factor", TreeBuilder.ROOT_EXTENT_STRETCH_FACTOR), - # Not used (but required by TRAVERSAL_PREAMBLE_TEMPLATE) - ("stick_out_factor", 0), ) preamble = Template( @@ -662,9 +659,7 @@ class AreaQueryBuilder(object): peer_list_idx_dtype=peer_list_idx_dtype, ball_id_dtype=ball_id_dtype, debug=False, - root_extent_stretch_factor=TreeBuilder.ROOT_EXTENT_STRETCH_FACTOR, - # Not used (but required by TRAVERSAL_PREAMBLE_TEMPLATE) - stick_out_factor=0) + root_extent_stretch_factor=TreeBuilder.ROOT_EXTENT_STRETCH_FACTOR) from pyopencl.tools import VectorArg, ScalarArg arg_decls = [ @@ -1064,8 +1059,6 @@ class PeerListFinder(object): AXIS_NAMES=AXIS_NAMES, box_flags_enum=box_flags_enum, debug=False, - # Not used (but required by TRAVERSAL_PREAMBLE_TEMPLATE) - stick_out_factor=0, # For calls to the helper is_adjacent_or_overlapping() targets_have_extent=False, sources_have_extent=False) diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 836dd2d8197edb52062eaeb4d00eba782a592733..27936c83706cff929e874cc3c047191e0a2668c4 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -135,7 +135,6 @@ typedef ${dtype_to_ctype(coord_dtype)} coord_t; typedef ${dtype_to_ctype(vec_types_dict[coord_dtype, dimensions])} coord_vec_t; #define NLEVELS ${max_levels} -#define STICK_OUT_FACTOR ((coord_t) ${stick_out_factor}) #define LEVEL_TO_RAD(level) \ (root_extent * 1 / (coord_t) (1 << (level + 1))) @@ -158,13 +157,12 @@ TRAVERSAL_PREAMBLE_TEMPLATE = ( HELPER_FUNCTION_TEMPLATE = r"""//CL// -inline bool is_adjacent_or_overlapping( +inline bool is_adjacent_or_overlapping_with_stick_out( coord_t root_extent, // target and source order only matter if include_stick_out is true. coord_vec_t target_center, int target_level, coord_vec_t source_center, int source_level, - // this is expected to be constant so that the inliner will kill the if. - const bool include_stick_out + const coord_t stick_out_factor ) { // This checks if the two boxes overlap @@ -178,18 +176,36 @@ inline bool is_adjacent_or_overlapping( coord_t rad_sum = target_rad + source_rad; coord_t slack = rad_sum + fmin(target_rad, source_rad); - if (include_stick_out) - { - slack += STICK_OUT_FACTOR * ( - 0 - %if targets_have_extent: - + target_rad - %endif - %if sources_have_extent: - + source_rad - %endif - ); - } + slack += stick_out_factor * ( + 0 + %if targets_have_extent: + + target_rad + %endif + %if sources_have_extent: + + source_rad + %endif + ); + + coord_t max_dist = 0; + %for i in range(dimensions): + max_dist = fmax(max_dist, fabs(target_center.s${i} - source_center.s${i})); + %endfor + + return max_dist <= slack; +} + + +inline bool is_adjacent_or_overlapping( + coord_t root_extent, + // note: order does not matter + coord_vec_t target_center, int target_level, + coord_vec_t source_center, int source_level) { + // This checks if the two boxes overlap. + + coord_t target_rad = LEVEL_TO_RAD(target_level); + coord_t source_rad = LEVEL_TO_RAD(source_level); + coord_t rad_sum = target_rad + source_rad; + coord_t slack = rad_sum + fmin(target_rad, source_rad); coord_t max_dist = 0; %for i in range(dimensions): @@ -299,7 +315,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t box_id) ${load_center("child_center", "child_box_id")} bool a_or_o = is_adjacent_or_overlapping(root_extent, - center, level, child_center, box_levels[child_box_id], false); + center, level, child_center, box_levels[child_box_id]); if (a_or_o) { @@ -380,7 +396,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t target_box_number) ${load_center("child_center", "child_box_id")} bool a_or_o = is_adjacent_or_overlapping(root_extent, - center, level, child_center, box_levels[child_box_id], false); + center, level, child_center, box_levels[child_box_id]); if (a_or_o) { @@ -451,7 +467,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) ${load_center("sib_center", "sib_box_id")} bool sep = !is_adjacent_or_overlapping(root_extent, - center, level, sib_center, box_levels[sib_box_id], false); + center, level, sib_center, box_levels[sib_box_id]); if (sep) { @@ -527,7 +543,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t target_box_number) int child_level = box_levels[child_box_id]; bool a_or_o = is_adjacent_or_overlapping(root_extent, - center, level, child_center, child_level, false); + center, level, child_center, child_level); if (a_or_o) { @@ -549,9 +565,9 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t target_box_number) { %if sources_have_extent or targets_have_extent: const bool a_or_o_with_stick_out = - is_adjacent_or_overlapping(root_extent, + is_adjacent_or_overlapping_with_stick_out(root_extent, center, level, child_center, - child_level, true); + child_level, stick_out_factor); %else: const bool a_or_o_with_stick_out = false; %endif @@ -714,7 +730,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) { ${load_center("colleague_center", "colleague_box_id")} bool a_or_o = is_adjacent_or_overlapping(root_extent, - center, box_level, colleague_center, walk_level, false); + center, box_level, colleague_center, walk_level); if (!a_or_o) { @@ -722,9 +738,9 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) %if sources_have_extent or targets_have_extent: const bool a_or_o_with_stick_out = - is_adjacent_or_overlapping(root_extent, + is_adjacent_or_overlapping_with_stick_out(root_extent, center, box_level, colleague_center, - walk_level, true); + walk_level, stick_out_factor); if (a_or_o_with_stick_out) { @@ -741,9 +757,9 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) %endif { bool parent_a_or_o_with_stick_out = - is_adjacent_or_overlapping(root_extent, + is_adjacent_or_overlapping_with_stick_out(root_extent, parent_center, box_level-1, colleague_center, - walk_level, true); + walk_level, stick_out_factor); if (parent_a_or_o_with_stick_out) { @@ -1148,8 +1164,7 @@ class FMMTraversalBuilder: @memoize_method def get_kernel_info(self, dimensions, particle_id_dtype, box_id_dtype, coord_dtype, box_level_dtype, max_levels, - sources_are_targets, sources_have_extent, targets_have_extent, - stick_out_factor): + sources_are_targets, sources_have_extent, targets_have_extent): logger.info("traversal build kernels: start build") @@ -1171,7 +1186,6 @@ class FMMTraversalBuilder: sources_are_targets=sources_are_targets, sources_have_extent=sources_have_extent, targets_have_extent=targets_have_extent, - stick_out_factor=stick_out_factor, ) from pyopencl.algorithm import ListOfListsBuilder from pyopencl.tools import VectorArg, ScalarArg @@ -1238,6 +1252,7 @@ class FMMTraversalBuilder: ], []), ("sep_smaller", SEP_SMALLER_TEMPLATE, [ + ScalarArg(coord_dtype, "stick_out_factor"), VectorArg(box_id_dtype, "target_boxes"), VectorArg(box_id_dtype, "colleagues_starts"), VectorArg(box_id_dtype, "colleagues_list"), @@ -1248,6 +1263,7 @@ class FMMTraversalBuilder: else []), ("sep_bigger", SEP_BIGGER_TEMPLATE, [ + ScalarArg(coord_dtype, "stick_out_factor"), VectorArg(box_id_dtype, "target_or_target_parent_boxes"), VectorArg(box_id_dtype, "box_parent_ids"), VectorArg(box_id_dtype, "colleagues_starts"), @@ -1307,8 +1323,7 @@ class FMMTraversalBuilder: tree.dimensions, tree.particle_id_dtype, tree.box_id_dtype, tree.coord_dtype, tree.box_level_dtype, max_levels, tree.sources_are_targets, - tree.sources_have_extent, tree.targets_have_extent, - tree.stick_out_factor) + tree.sources_have_extent, tree.targets_have_extent) def fin_debug(s): if debug: @@ -1439,7 +1454,7 @@ class FMMTraversalBuilder: queue, len(target_boxes), tree.box_centers.data, tree.root_extent, tree.box_levels.data, tree.aligned_nboxes, tree.box_child_ids.data, tree.box_flags.data, - target_boxes.data, + tree.stick_out_factor, target_boxes.data, colleagues.starts.data, colleagues.lists.data) wait_for = [] @@ -1480,7 +1495,8 @@ class FMMTraversalBuilder: queue, len(target_or_target_parent_boxes), tree.box_centers.data, tree.root_extent, tree.box_levels.data, tree.aligned_nboxes, tree.box_child_ids.data, tree.box_flags.data, - target_or_target_parent_boxes.data, tree.box_parent_ids.data, + tree.stick_out_factor, target_or_target_parent_boxes.data, + tree.box_parent_ids.data, colleagues.starts.data, colleagues.lists.data, wait_for=wait_for) wait_for = [evt] sep_bigger = result["sep_bigger"] diff --git a/boxtree/tree_build.py b/boxtree/tree_build.py index 8c6df0dc1f3d47e286471abcb6904552b64f64fb..f9d116f88ce554d5173a25b3aceff5d9d8d41f4b 100644 --- a/boxtree/tree_build.py +++ b/boxtree/tree_build.py @@ -62,13 +62,13 @@ class TreeBuilder(object): def get_kernel_info(self, dimensions, coord_dtype, particle_id_dtype, box_id_dtype, sources_are_targets, srcntgts_have_extent, - stick_out_factor, kind): + kind): from boxtree.tree_build_kernels import get_tree_build_kernel_info return get_tree_build_kernel_info(self.context, dimensions, coord_dtype, particle_id_dtype, box_id_dtype, sources_are_targets, srcntgts_have_extent, - stick_out_factor, self.morton_nr_dtype, self.box_level_dtype, + self.morton_nr_dtype, self.box_level_dtype, kind=kind) # {{{ run control @@ -186,7 +186,7 @@ class TreeBuilder(object): knl_info = self.get_kernel_info(dimensions, coord_dtype, particle_id_dtype, box_id_dtype, sources_are_targets, srcntgts_have_extent, - stick_out_factor, kind=kind) + kind=kind) logger.info("tree build: start") @@ -535,9 +535,13 @@ class TreeBuilder(object): fin_debug("morton count scan") + morton_count_args = common_args + if srcntgts_have_extent: + morton_count_args += (stick_out_factor,) + # writes: box_morton_bin_counts evt = knl_info.morton_count_scan( - *common_args, queue=queue, size=nsrcntgts, + *morton_count_args, queue=queue, size=nsrcntgts, wait_for=wait_for) wait_for = [evt] diff --git a/boxtree/tree_build_kernels.py b/boxtree/tree_build_kernels.py index 7dfa08b86970454746537dfd77a0fd0039b08c1c..1a8a592ac49c24e8e4a54f8f987d954eeda3a75c 100644 --- a/boxtree/tree_build_kernels.py +++ b/boxtree/tree_build_kernels.py @@ -184,7 +184,6 @@ TYPE_DECL_PREAMBLE_TPL = Template(r"""//CL// """, strict_undefined=True) GENERIC_PREAMBLE_TPL = Template(r"""//CL// - #define STICK_OUT_FACTOR ((coord_t) ${stick_out_factor}) // Use this as dbg_printf(("oh snap: %d\n", stuff)); Note the double // parentheses. @@ -291,6 +290,7 @@ MORTON_NR_SCAN_PREAMBLE_TPL = Template(r"""//CL// %endfor %if srcntgts_have_extent: , global const coord_t *srcntgt_radii + , const coord_t stick_out_factor %endif ) { @@ -306,13 +306,22 @@ MORTON_NR_SCAN_PREAMBLE_TPL = Template(r"""//CL// coord_t srcntgt_radius = srcntgt_radii[user_srcntgt_id]; %endif + %if not srcntgts_have_extent: + // This argument is only supplied with srcntgts_have_extent. + #define stick_out_factor 0. + %endif + const coord_t one_half = ((coord_t) 1) / 2; const coord_t box_radius_factor = // AMD CPU seems to like to miscompile this--change with care. // (last seen on 13.4-2) - (1. + STICK_OUT_FACTOR) + (1. + stick_out_factor) * one_half; // convert diameter to radius + %if not srcntgts_have_extent: + #undef stick_out_factor + %endif + %for ax in axis_names: // Most FMMs are isotropic, i.e. global_extent_{x,y,z} are all the same. // Nonetheless, the gain from exploiting this assumption seems so @@ -339,7 +348,7 @@ MORTON_NR_SCAN_PREAMBLE_TPL = Template(r"""//CL// * (1U << (1 + particle_level))); %if srcntgts_have_extent: - // Need to compute center to compare excess with STICK_OUT_FACTOR. + // Need to compute center to compare excess with stick_out_factor. coord_t next_level_box_center_${ax} = global_min_${ax} + global_extent_${ax} @@ -801,8 +810,7 @@ LEVEL_RESTRICT_TPL = Template( ${my_load_center("box_center", "box_id")} ${my_load_center("child_center", "child_box_id")} is_adjacent = is_adjacent_or_overlapping( - root_extent, child_center, child_level, box_center, level, - false); + root_extent, child_center, child_level, box_center, level); } if (is_adjacent) @@ -1238,8 +1246,7 @@ BOX_INFO_KERNEL_TPL = ElementwiseTemplate( def get_tree_build_kernel_info(context, dimensions, coord_dtype, particle_id_dtype, box_id_dtype, sources_are_targets, srcntgts_have_extent, - stick_out_factor, morton_nr_dtype, box_level_dtype, - kind): + morton_nr_dtype, box_level_dtype, kind): level_restrict = (kind == "adaptive-level-restricted") adaptive = not (kind == "non-adaptive") @@ -1297,8 +1304,6 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, sources_are_targets=sources_are_targets, srcntgts_have_extent=srcntgts_have_extent, - stick_out_factor=stick_out_factor, - enable_assert=False, enable_printf=False, ) @@ -1378,10 +1383,17 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, if srcntgts_have_extent else []) ) + morton_count_scan_arguments = list(common_arguments) + + if srcntgts_have_extent: + morton_count_scan_arguments += [ + (ScalarArg(coord_dtype, "stick_out_factor")) + ] + from pyopencl.scan import GenericScanKernel morton_count_scan = GenericScanKernel( context, morton_bin_count_dtype, - arguments=common_arguments, + arguments=morton_count_scan_arguments, input_expr=( "scan_t_from_particle(%s)" % ", ".join([ @@ -1390,7 +1402,8 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, "refine_weights", ] + ["%s" % ax for ax in axis_names] - + (["srcntgt_radii"] if srcntgts_have_extent else []))), + + (["srcntgt_radii, stick_out_factor"] + if srcntgts_have_extent else []))), scan_expr="scan_t_add(a, b, across_seg_boundary)", neutral="scan_t_neutral()", is_segment_start_expr="box_start_flags[i]",