diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 4377f931132c63487e3bb636100fee000124a7bd..a35c60d08cfcc9850b49f50fd7fd4b561c1019d4 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -105,11 +105,40 @@ TRAVERSAL_PREAMBLE_MAKO_DEFS = r"""//CL:mako// <%def name="load_center(name, box_id, declare=True)"> %if declare: - coord_vec_t ${name}; + coord_vec_t ${name} = (coord_vec_t)( + %else: + ${name} = (coord_vec_t)( %endif - %for i in range(dimensions): - ${name}.${AXIS_NAMES[i]} = box_centers[aligned_nboxes * ${i} + ${box_id}]; - %endfor + %for i in range(dimensions): + box_centers[aligned_nboxes * ${i} + ${box_id}] + %if i + 1 < dimensions: + , + %endif + %endfor + ); + + +<%def name="load_true_box_extent(name, box_id, kind, declare=True)"> + %if declare: + coord_vec_t ${name}_ext_center, ${name}_radii_vec; + %endif + + { + %for bound in ["min", "max"]: + coord_vec_t ${name}_${bound} = (coord_vec_t)( + %for iaxis in range(dimensions): + box_${kind}_bounding_box_${bound}[ + ${iaxis} * aligned_nboxes + ${box_id}] + %if iaxis + 1 < dimensions: + , + %endif + %endfor + ); + %endfor + + ${name}_ext_center = 0.5*(${name}_min + ${name}_max); + ${name}_radii_vec = 0.5*(${name}_max - ${name}_min); + } <%def name="check_l_infty_ball_overlap( @@ -117,15 +146,12 @@ TRAVERSAL_PREAMBLE_MAKO_DEFS = r"""//CL:mako// { ${load_center("box_center", box_id)} int box_level = box_levels[${box_id}]; - coord_t size_sum = LEVEL_TO_RAD(box_level) + ${ball_radius}; - coord_t max_dist = 0; %for i in range(dimensions): max_dist = fmax(max_dist, fabs(${ball_center}.s${i} - box_center.s${i})); %endfor - ${is_overlapping} = max_dist <= size_sum; } @@ -157,6 +183,8 @@ typedef ${dtype_to_ctype(vec_types_dict[coord_dtype, dimensions])} coord_vec_t; %else: #define dbg_printf(ARGS) /* */ %endif + +#define square(x) ((x)*(x)) """ @@ -298,6 +326,98 @@ LEVEL_START_BOX_NR_EXTRACTOR_TEMPLATE = ElementwiseTemplate( # }}} +# {{{ box extents + +BOX_EXTENTS_FINDER_TEMPLATE = 
ElementwiseTemplate( + arguments="""//CL:mako// + box_id_t aligned_nboxes, + box_id_t *box_child_ids, + coord_t *box_centers, + particle_id_t *box_particle_starts, + particle_id_t *box_particle_counts_nonchild + + %for iaxis in range(dimensions): + , const coord_t *particle_${AXIS_NAMES[iaxis]} + %endfor + , + const coord_t *particle_radii, + int enable_radii, + + coord_t *box_particle_bounding_box_min, + coord_t *box_particle_bounding_box_max, + """, + + operation=TRAVERSAL_PREAMBLE_MAKO_DEFS + r"""//CL:mako// + box_id_t ibox = i; + + ${load_center("box_center", "ibox")} + + <% axis_names = AXIS_NAMES[:dimensions] %> + + // incorporate own particles + %for iaxis, ax in enumerate(axis_names): + coord_t min_particle_${ax} = box_center.s${iaxis}; + coord_t max_particle_${ax} = box_center.s${iaxis}; + %endfor + + particle_id_t start = box_particle_starts[ibox]; + particle_id_t stop = start + box_particle_counts_nonchild[ibox]; + + for (particle_id_t iparticle = start; iparticle < stop; ++iparticle) + { + coord_t particle_rad = 0; + %if sources_have_extent or targets_have_extent: + // If only one has extent, then the radius array for the other + // may well be a null pointer. 
+ if (enable_radii) + particle_rad = particle_radii[iparticle]; + %endif + + %for iaxis, ax in enumerate(axis_names): + coord_t particle_coord_${ax} = particle_${ax}[iparticle]; + + min_particle_${ax} = min( + min_particle_${ax}, + particle_coord_${ax} - particle_rad); + max_particle_${ax} = max( + max_particle_${ax}, + particle_coord_${ax} + particle_rad); + %endfor + } + + // incorporate child boxes + for (int morton_nr = 0; morton_nr < ${2**dimensions}; ++morton_nr) + { + box_id_t child_id = box_child_ids[ + morton_nr * aligned_nboxes + ibox]; + + if (child_id == 0) + continue; + + %for iaxis, ax in enumerate(axis_names): + min_particle_${ax} = min( + min_particle_${ax}, + box_particle_bounding_box_min[ + ${iaxis} * aligned_nboxes + child_id]); + max_particle_${ax} = max( + max_particle_${ax}, + box_particle_bounding_box_max[ + ${iaxis} * aligned_nboxes + child_id]); + %endfor + } + + // write result + %for iaxis, ax in enumerate(axis_names): + box_particle_bounding_box_min[ + ${iaxis} * aligned_nboxes + ibox] = min_particle_${ax}; + box_particle_bounding_box_max[ + ${iaxis} * aligned_nboxes + ibox] = max_particle_${ax}; + %endfor + """, + name="find_box_extents") + +# }}} + # {{{ same-level non-well-separated boxes (generalization of "colleagues") SAME_LEVEL_NON_WELL_SEP_BOXES_TEMPLATE = r"""//CL// @@ -485,6 +605,9 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) box_id_t sib_box_id = box_child_ids[ morton_nr * aligned_nboxes + parent_nf]; + if (sib_box_id == 0) + continue; + ${load_center("sib_center", "sib_box_id")} bool sep = !is_adjacent_or_overlapping_with_neighborhood( @@ -508,49 +631,38 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) FROM_SEP_SMALLER_TEMPLATE = r"""//CL// -inline bool meets_sep_smaller_criterion( - coord_t root_extent, - coord_vec_t target_center, int target_level, - coord_vec_t source_center, int source_level, - coord_t stick_out_factor) -{ - coord_t target_rad = 
LEVEL_TO_RAD(target_level); - coord_t source_rad = LEVEL_TO_RAD(source_level); - coord_t min_allowed_center_l_inf_dist = ( - 3 * source_rad - + (1 + stick_out_factor) * target_rad); - - coord_t l_inf_dist = 0; - %for i in range(dimensions): - l_inf_dist = fmax( - l_inf_dist, - fabs(target_center.s${i} - source_center.s${i})); - %endfor - - return l_inf_dist >= min_allowed_center_l_inf_dist * (1 - 8 * COORD_T_MACH_EPS); -} - - void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t target_box_number) { // /!\ target_box_number is *not* a box_id, despite the type. // It's the number of the target box we're currently processing. - box_id_t box_id = target_boxes[target_box_number]; + box_id_t tgt_box_id = target_boxes[target_box_number]; - ${load_center("center", "box_id")} + ${load_center("tgt_center", "tgt_box_id")} - int level = box_levels[box_id]; + int tgt_level = box_levels[tgt_box_id]; + + %if targets_have_extent: + %if from_sep_smaller_crit in ["static_linf", "static_l2"]: + coord_t tgt_stickout_l_inf_rad = + (1 + stick_out_factor) * LEVEL_TO_RAD(tgt_level); + + %elif from_sep_smaller_crit == "precise_linf": + ${load_true_box_extent("tgt", "tgt_box_id", "target")} + // defines tgt_ext_center, tgt_radii_vec - box_id_t slnws_start = same_level_non_well_sep_boxes_starts[box_id]; - box_id_t slnws_stop = same_level_non_well_sep_boxes_starts[box_id+1]; + %endif + %endif + + box_id_t slnws_start = same_level_non_well_sep_boxes_starts[tgt_box_id]; + box_id_t slnws_stop = same_level_non_well_sep_boxes_starts[tgt_box_id+1]; // /!\ i is not a box_id, it's an index into same_level_non_well_sep_boxes_lists. 
for (box_id_t i = slnws_start; i < slnws_stop; ++i) { box_id_t same_lev_nws_box = same_level_non_well_sep_boxes_lists[i]; - if (same_lev_nws_box == box_id) + if (same_lev_nws_box == tgt_box_id) continue; // Colleagues (same-level NWS boxes) for 1-away are always adjacent, so @@ -563,21 +675,21 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t target_box_number) while (continue_walk) { // Loop invariant: - // walk_parent_box_id is, at first, always adjacent to box_id. + // walk_parent_box_id is, at first, always adjacent to tgt_box_id. // // This is true at the first level because colleagues are adjacent // by definition, and is kept true throughout the walk by only descending // into adjacent boxes. // // As we descend, we may find a child of an adjacent box that is - // non-adjacent to box_id. + // non-adjacent to tgt_box_id. // // If neither sources nor targets have extent, then that - // nonadjacent child box is added to box_id's from_sep_smaller ("list 3 - // far") and that's it. + // nonadjacent child box is added to tgt_box_id's from_sep_smaller + // ("list 3 far") and that's it. // // If they have extent, then while they may be separated, the - // intersection of box_id's and the child box's stick-out region + // intersection of tgt_box_id's and the child box's stick-out region // may be non-empty, and we thus need to add that child to // from_sep_close_smaller ("list 3 close") for the interaction to be // done by direct evaluation. 
We also need to descend into that @@ -599,7 +711,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t target_box_number) int walk_level = box_levels[walk_box_id]; bool in_list_1 = is_adjacent_or_overlapping(root_extent, - center, level, walk_center, walk_level); + tgt_center, tgt_level, walk_center, walk_level); if (in_list_1) { @@ -619,21 +731,88 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t target_box_number) } else { - %if sources_have_extent or targets_have_extent: - const bool meets_crit = - meets_sep_smaller_criterion(root_extent, - center, level, - walk_center, walk_level, - stick_out_factor); + bool meets_sep_crit; + + <% assert not sources_have_extent %> + + %if not targets_have_extent: + meets_sep_crit = true; + + %elif from_sep_smaller_crit == "static_linf": + { + coord_t source_rad = LEVEL_TO_RAD(walk_level); + + // l^infty distance between source box and target box. + // Negative indicates overlap. + coord_t l_inf_dist = 0; + %for i in range(dimensions): + l_inf_dist = fmax( + l_inf_dist, + fabs(tgt_center.s${i} - walk_center.s${i}) + - tgt_stickout_l_inf_rad + - source_rad); + %endfor + + meets_sep_crit = l_inf_dist >= + (2 - 8 * COORD_T_MACH_EPS) * source_rad; + } + + %elif from_sep_smaller_crit == "precise_linf": + { + coord_t source_rad = LEVEL_TO_RAD(walk_level); + + // l^infty distance between source box and target box. + // Negative indicates overlap. + coord_t l_inf_dist = 0; + %for i in range(dimensions): + l_inf_dist = fmax( + l_inf_dist, + fabs(tgt_ext_center.s${i} - walk_center.s${i}) + - tgt_radii_vec.s${i} + - source_rad); + %endfor + + meets_sep_crit = l_inf_dist >= + (2 - 8 * COORD_T_MACH_EPS) * source_rad; + } + + %elif from_sep_smaller_crit == "static_l2": + { + coord_t source_l_inf_rad = LEVEL_TO_RAD(walk_level); + + // l^2 distance between source box and target centers. 
+ coord_t l_2_squared_center_dist = + 0 + %for i in range(dimensions): + + square(tgt_center.s${i} - walk_center.s${i}) + %endfor + ; + + // l^2 distance between source box and target box. + // Negative indicates overlap. + coord_t l_2_box_dist = + sqrt(l_2_squared_center_dist) + - sqrt((coord_t) (${dimensions})) + * tgt_stickout_l_inf_rad + - source_l_inf_rad; + + meets_sep_crit = l_2_box_dist >= + (2 - 8 * COORD_T_MACH_EPS) * source_l_inf_rad; + } + %else: - const bool meets_crit = true; + <% raise ValueError( + "unknown value of from_sep_smaller_crit: %s" + % from_sep_smaller_crit) %> %endif // We're no longer *immediately* adjacent to our target // box, but our stick-out regions might still have a // non-empty intersection. - if (meets_crit) + if (meets_sep_crit + && box_source_counts_cumul[walk_box_id] + >= from_sep_smaller_min_nsources_cumul) { if (from_sep_smaller_source_level == walk_level) APPEND_from_sep_smaller(walk_box_id); @@ -756,9 +935,9 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) if (tgt_box_level == 0) return; - box_id_t parent_box_id = box_parent_ids[tgt_ibox]; - const int parent_level = tgt_box_level - 1; - ${load_center("parent_center", "parent_box_id")} + box_id_t tgt_parent_box_id = box_parent_ids[tgt_ibox]; + const int tgt_parent_level = tgt_box_level - 1; + ${load_center("parent_center", "tgt_parent_box_id")} box_flags_t tgt_box_flags = box_flags[tgt_ibox]; @@ -768,13 +947,13 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) // may directly jump to the parent level. int walk_level = tgt_box_level - 1; - box_id_t current_parent_box_id = parent_box_id; + box_id_t current_tgt_parent_box_id = tgt_parent_box_id; %else: // In a 2+-away FMM, tgt_ibox's same-level non-well-separated boxes *may* // be sufficiently separated from tgt_ibox to be in its list 4. 
int walk_level = tgt_box_level; - box_id_t current_parent_box_id = tgt_ibox; + box_id_t current_tgt_parent_box_id = tgt_ibox; %endif /* @@ -788,14 +967,14 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) for (; walk_level != 0; // {{{ advance --walk_level, - current_parent_box_id = box_parent_ids[current_parent_box_id] + current_tgt_parent_box_id = box_parent_ids[current_tgt_parent_box_id] // }}} ) { box_id_t slnws_start = - same_level_non_well_sep_boxes_starts[current_parent_box_id]; + same_level_non_well_sep_boxes_starts[current_tgt_parent_box_id]; box_id_t slnws_stop = - same_level_non_well_sep_boxes_starts[current_parent_box_id+1]; + same_level_non_well_sep_boxes_starts[current_tgt_parent_box_id+1]; // /!\ i is not a box id, it's an index into // same_level_non_well_sep_boxes_lists. @@ -844,7 +1023,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) { bool in_parent_list_1 = is_adjacent_or_overlapping(root_extent, - parent_center, parent_level, + parent_center, tgt_parent_level, slnws_center, walk_level); bool would_be_in_parent_list_4_not_considering_stickout = ( @@ -892,7 +1071,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t itarget_or_target_parent_box) %if sources_have_extent or targets_have_extent: const bool parent_meets_with_ext_sep_criterion = meets_sep_bigger_criterion(root_extent, - parent_center, parent_level, + parent_center, tgt_parent_level, slnws_center, walk_level, stick_out_factor); @@ -1013,6 +1192,31 @@ class FMMTraversalInfo(DeviceDataRecord): Indices into :attr:`target_or_target_parent_boxes` indicating where each level starts and ends. + .. ------------------------------------------------------------------------ + .. rubric:: Box extents + .. ------------------------------------------------------------------------ + + The attributes in this section are only available if the respective + particle type (source/target) has extents. 
+ + If they are not available, the corresponding attributes will be *None*. + + .. attribute:: box_source_bounding_box_min + + ``coord_t [dimensions, aligned_nboxes]`` + + .. attribute:: box_source_bounding_box_max + + ``coord_t [dimensions, aligned_nboxes]`` + + .. attribute:: box_target_bounding_box_min + + ``coord_t [dimensions, aligned_nboxes]`` + + .. attribute:: box_target_bounding_box_max + + ``coord_t [dimensions, aligned_nboxes]`` + .. ------------------------------------------------------------------------ .. rubric:: Same-level non-well-separated boxes .. ------------------------------------------------------------------------ @@ -1308,16 +1512,70 @@ class _KernelInfo(Record): class FMMTraversalBuilder: - def __init__(self, context, well_sep_is_n_away=1): + def __init__(self, context, well_sep_is_n_away=1, from_sep_smaller_crit=None): + """ + :arg well_sep_is_n_away: An integer 1 or greater. (Only 2 is tested) + The spacing between boxes that is considered "well-separated" for + :attr:`from_sep_siblings` (List 2). + :arg from_sep_smaller_crit: The criterion used to determine + box dimensions and separation for :attr:`from_sep_smaller_by_level` + (List 3). May be one of ``"static_linf"`` (use the box square, + possibly enlarged by :attr:`Tree.stick_out_factor`), ``"precise_linf"`` + (use the precise extent of targets in the box, including their radii), + or ``"static_l2"`` (use the circumcircle of the box, + possibly enlarged by :attr:`Tree.stick_out_factor`). 
+ """ self.context = context self.well_sep_is_n_away = well_sep_is_n_away + self.from_sep_smaller_crit = from_sep_smaller_crit # {{{ kernel builder @memoize_method def get_kernel_info(self, dimensions, particle_id_dtype, box_id_dtype, coord_dtype, box_level_dtype, max_levels, - sources_are_targets, sources_have_extent, targets_have_extent): + sources_are_targets, sources_have_extent, targets_have_extent, + extent_norm): + + # {{{ process from_sep_smaller_crit + + from_sep_smaller_crit = self.from_sep_smaller_crit + + if from_sep_smaller_crit is None: + from_sep_smaller_crit = "precise_linf" + + if extent_norm == "linf": + # no special checks needed + pass + + elif extent_norm == "l2": + if from_sep_smaller_crit == "static_linf": + # Not technically necessary, but static linf will assume box + # bounds that are not guaranteed to contain all particle + # extents. + raise ValueError( + "The static l^inf from-sep-smaller criterion " + "cannot be used with the l^2 extent norm") + + elif extent_norm is None: + assert not (sources_have_extent or targets_have_extent) + + if from_sep_smaller_crit is None: + # doesn't matter + from_sep_smaller_crit = "static_linf" + + else: + raise ValueError("unexpected value of 'extent_norm': %s" + % extent_norm) + + if from_sep_smaller_crit not in [ + "static_linf", "precise_linf", + "static_l2", + ]: + raise ValueError("unexpected value of 'from_sep_smaller_crit': %s" + % from_sep_smaller_crit) + + # }}} logger.info("traversal build kernels: start build") @@ -1341,6 +1599,7 @@ class FMMTraversalBuilder: sources_have_extent=sources_have_extent, targets_have_extent=targets_have_extent, well_sep_is_n_away=self.well_sep_is_n_away, + from_sep_smaller_crit=from_sep_smaller_crit, ) from pyopencl.algorithm import ListOfListsBuilder from pyopencl.tools import VectorArg, ScalarArg @@ -1371,6 +1630,23 @@ class FMMTraversalBuilder: debug=debug, name_prefix="sources_parents_and_targets") + result["box_extents_finder"] = \ + 
BOX_EXTENTS_FINDER_TEMPLATE.build(self.context, + type_aliases=( + ("box_id_t", box_id_dtype), + ("coord_t", coord_dtype), + ("coord_vec_t", cl.cltypes.vec_types[ + coord_dtype, dimensions]), + ("particle_id_t", particle_id_dtype), + ), + var_values=( + ("dimensions", dimensions), + ("AXIS_NAMES", AXIS_NAMES), + ("sources_have_extent", sources_have_extent), + ("targets_have_extent", targets_have_extent), + ), + ) + result["level_start_box_nrs_extractor"] = \ LEVEL_START_BOX_NR_EXTRACTOR_TEMPLATE.build(self.context, type_aliases=( @@ -1416,6 +1692,11 @@ class FMMTraversalBuilder: "same_level_non_well_sep_boxes_starts"), VectorArg(box_id_dtype, "same_level_non_well_sep_boxes_lists"), + VectorArg(coord_dtype, "box_target_bounding_box_min"), + VectorArg(coord_dtype, "box_target_bounding_box_max"), + VectorArg(particle_id_dtype, "box_source_counts_cumul"), + ScalarArg(particle_id_dtype, + "from_sep_smaller_min_nsources_cumul"), ScalarArg(box_id_dtype, "from_sep_smaller_source_level"), ], ["from_sep_close_smaller"] @@ -1460,7 +1741,8 @@ class FMMTraversalBuilder: # {{{ driver - def __call__(self, queue, tree, wait_for=None, debug=False): + def __call__(self, queue, tree, wait_for=None, debug=False, + _from_sep_smaller_min_nsources_cumul=None): """ :arg queue: A :class:`pyopencl.CommandQueue` instance. :arg tree: A :class:`boxtree.Tree` instance. @@ -1472,6 +1754,10 @@ class FMMTraversalBuilder: for dependency management. 
""" + if _from_sep_smaller_min_nsources_cumul is None: + # default to old no-threshold behavior + _from_sep_smaller_min_nsources_cumul = 0 + if not tree._is_pruned: raise ValueError("tree must be pruned for traversal generation") @@ -1490,7 +1776,8 @@ class FMMTraversalBuilder: tree.dimensions, tree.particle_id_dtype, tree.box_id_dtype, tree.coord_dtype, tree.box_level_dtype, max_levels, tree.sources_are_targets, - tree.sources_have_extent, tree.targets_have_extent) + tree.sources_have_extent, tree.targets_have_extent, + tree.extent_norm) def fin_debug(s): if debug: @@ -1567,7 +1854,91 @@ class FMMTraversalBuilder: # }}} - # {{{ same-level near-field + # {{{ box extents + + fin_debug("finding box extents") + + box_source_bounding_box_min = cl.array.empty( + queue, (tree.dimensions, tree.aligned_nboxes), + dtype=tree.coord_dtype) + box_source_bounding_box_max = cl.array.empty( + queue, (tree.dimensions, tree.aligned_nboxes), + dtype=tree.coord_dtype) + + if tree.sources_are_targets: + box_target_bounding_box_min = box_source_bounding_box_min + box_target_bounding_box_max = box_source_bounding_box_max + else: + box_target_bounding_box_min = cl.array.empty( + queue, (tree.dimensions, tree.aligned_nboxes), + dtype=tree.coord_dtype) + box_target_bounding_box_max = cl.array.empty( + queue, (tree.dimensions, tree.aligned_nboxes), + dtype=tree.coord_dtype) + + bogus_radii_array = cl.array.empty(queue, 1, dtype=tree.coord_dtype) + + # nlevels-1 is the highest valid level index + for level in range(tree.nlevels-1, -1, -1): + start, stop = tree.level_start_box_nrs[level:level+2] + + for (skip, enable_radii, bbox_min, bbox_max, + pstarts, pcounts, radii_tree_attr, particles) in [ + ( + # never skip + False, + + tree.sources_have_extent, + box_source_bounding_box_min, + box_source_bounding_box_max, + tree.box_source_starts, + tree.box_source_counts_nonchild, + "source_radii", + tree.sources), + ( + # skip the 'target' round if sources and targets + # are the same. 
+ tree.sources_are_targets, + + tree.targets_have_extent, + box_target_bounding_box_min, + box_target_bounding_box_max, + tree.box_target_starts, + tree.box_target_counts_nonchild, + "target_radii", + tree.targets), + ]: + + if skip: + continue + + args = ( + ( + tree.aligned_nboxes, + tree.box_child_ids, + tree.box_centers, + pstarts, pcounts,) + + tuple(particles) + + ( + getattr(tree, radii_tree_attr, bogus_radii_array), + enable_radii, + + bbox_min, + bbox_max)) + + evt = knl_info.box_extents_finder( + *args, + + range=slice(start, stop), + queue=queue, wait_for=wait_for) + + wait_for = [evt] + + del bogus_radii_array + + # }}} + + # {{{ same-level non-well-separated boxes # If well_sep_is_n_away is 1, this agrees with the definition of # 'colleagues' from the classical FMM literature. @@ -1629,6 +2000,10 @@ class FMMTraversalBuilder: tree.stick_out_factor, target_boxes.data, same_level_non_well_sep_boxes.starts.data, same_level_non_well_sep_boxes.lists.data, + box_target_bounding_box_min.data, + box_target_bounding_box_max.data, + tree.box_source_counts_cumul.data, + _from_sep_smaller_min_nsources_cumul, ) from_sep_smaller_wait_for = [] @@ -1718,6 +2093,11 @@ class FMMTraversalBuilder: level_start_target_or_target_parent_box_nrs=( level_start_target_or_target_parent_box_nrs), + box_source_bounding_box_min=box_source_bounding_box_min, + box_source_bounding_box_max=box_source_bounding_box_max, + box_target_bounding_box_min=box_target_bounding_box_min, + box_target_bounding_box_max=box_target_bounding_box_max, + same_level_non_well_sep_boxes_starts=( same_level_non_well_sep_boxes.starts), same_level_non_well_sep_boxes_lists=( diff --git a/boxtree/tree.py b/boxtree/tree.py index 84b0a94f83fb12f29fc47e048a3c7e9ea96a0bd0..518d61ace61d07492047e7ef9e0a4feed2b74039 100644 --- a/boxtree/tree.py +++ b/boxtree/tree.py @@ -56,7 +56,7 @@ class box_flags_enum(Enum): # noqa # {{{ tree data structure class Tree(DeviceDataRecord): - """A quad/octree consisting of particles 
sorted into a hierarchy of boxes. + r"""A quad/octree consisting of particles sorted into a hierarchy of boxes. Optionally, particles may be designated 'sources' and 'targets'. They may also be assigned radii which restrict the minimum size of the box into which they may be sorted. @@ -106,9 +106,29 @@ class Tree(DeviceDataRecord): .. attribute:: stick_out_factor - The fraction of the (:math:`l^\infty`) box radius by which the - :math:`l^\infty` circles given by :attr:`source_radii` may stick out - the box in which they are contained. A scalar. + A scalar used for calculating how much particles with extent may + overextend their containing box. + + Each box in the tree can be thought of as being surrounded by a + fictitious box whose :math:`l^\infty` radius is a factor of `1 + stick_out_factor` + larger. Particles with extent are allowed to extend inside (a) the + fictitious box or (b) a disk surrounding the fictitious box, depending on + :attr:`extent_norm`. + + .. attribute:: extent_norm + + One of ``None``, ``"l2"`` or ``"linf"``. If *None*, particles do not have + extent. If not *None*, indicates the norm with which extent-bearing particles + are determined to lie 'inside' a box, taking into account the box's + :attr:`stick_out_factor`. + + This image illustrates the difference in semantics: + + .. image:: images/linf-l2.png + + In the figure, the box has (:math:`\ell^\infty`) radius :math:`R`, the + particle has radius :math:`r`, and :attr:`stick_out_factor` is denoted + :math:`\alpha`. .. 
attribute:: nsources diff --git a/boxtree/tree_build.py b/boxtree/tree_build.py index 9ad06925725a97a4aeab40e49aaf33dad34c65c1..2566dc44de18e858fff799457d8960fc8997c9d2 100644 --- a/boxtree/tree_build.py +++ b/boxtree/tree_build.py @@ -61,13 +61,13 @@ class TreeBuilder(object): @memoize_method def get_kernel_info(self, dimensions, coord_dtype, particle_id_dtype, box_id_dtype, - sources_are_targets, srcntgts_have_extent, + sources_are_targets, srcntgts_extent_norm, kind): from boxtree.tree_build_kernels import get_tree_build_kernel_info return get_tree_build_kernel_info(self.context, dimensions, coord_dtype, particle_id_dtype, box_id_dtype, - sources_are_targets, srcntgts_have_extent, + sources_are_targets, srcntgts_extent_norm, self.morton_nr_dtype, self.box_level_dtype, kind=kind) @@ -77,7 +77,9 @@ class TreeBuilder(object): max_particles_in_box=None, allocator=None, debug=False, targets=None, source_radii=None, target_radii=None, stick_out_factor=None, refine_weights=None, - max_leaf_refine_weight=None, wait_for=None, **kwargs): + max_leaf_refine_weight=None, wait_for=None, + extent_norm=None, + **kwargs): """ :arg queue: a :class:`pyopencl.CommandQueue` instance :arg particles: an object array of (XYZ) point coordinate arrays. @@ -114,6 +116,8 @@ class TreeBuilder(object): :arg wait_for: may either be *None* or a list of :class:`pyopencl.Event` instances for whose completion this command waits before starting execution. + :arg extent_norm: ``"l2"`` or ``"linf"``. Indicates the norm with respect + to which particle stick-out is measured. See :attr:`Tree.extent_norm`. :arg kwargs: Used internally for debugging. 
:returns: a tuple ``(tree, event)``, where *tree* is an instance of @@ -140,9 +144,22 @@ class TreeBuilder(object): sources_are_targets = targets is None sources_have_extent = source_radii is not None targets_have_extent = target_radii is not None + + if extent_norm is None: + extent_norm = "linf" + + if extent_norm not in ["linf", "l2"]: + raise ValueError("unexpected value of 'extent_norm': %s" + % extent_norm) + + srcntgts_extent_norm = extent_norm srcntgts_have_extent = sources_have_extent or targets_have_extent + if not srcntgts_have_extent: + srcntgts_extent_norm = None + + del extent_norm - if srcntgts_have_extent and targets is None: + if srcntgts_extent_norm and targets is None: raise ValueError("must specify targets when specifying " "any kind of radii") @@ -192,7 +209,7 @@ class TreeBuilder(object): knl_info = self.get_kernel_info(dimensions, coord_dtype, particle_id_dtype, box_id_dtype, - sources_are_targets, srcntgts_have_extent, + sources_are_targets, srcntgts_extent_norm, kind=kind) logger.info("tree build: start") @@ -1559,6 +1576,7 @@ class TreeBuilder(object): root_extent=root_extent, stick_out_factor=stick_out_factor, + extent_norm=srcntgts_extent_norm, bounding_box=(bbox_min, bbox_max), level_start_box_nrs=level_start_box_nrs, diff --git a/boxtree/tree_build_kernels.py b/boxtree/tree_build_kernels.py index a1588653d845487a80f4f981394bc0fc9dcc2013..5c9bfcbffb762b0077e810182508962e11424927 100644 --- a/boxtree/tree_build_kernels.py +++ b/boxtree/tree_build_kernels.py @@ -324,7 +324,8 @@ MORTON_NR_SCAN_PREAMBLE_TPL = Template(r"""//CL// %for ax in axis_names: // Most FMMs are isotropic, i.e. global_extent_{x,y,z} are all the same. // Nonetheless, the gain from exploiting this assumption seems so - // minimal that doing so here didn't seem worthwhile. + // minimal that doing so here didn't seem worthwhile in the + // srcntgts_extent_norm == "linf" case. 
coord_t global_min_${ax} = bbox->min_${ax}; coord_t global_extent_${ax} = bbox->max_${ax} - global_min_${ax}; @@ -346,19 +347,24 @@ MORTON_NR_SCAN_PREAMBLE_TPL = Template(r"""//CL// ((srcntgt_${ax} - global_min_${ax}) / global_extent_${ax}) * (1U << (1 + particle_level))); - %if srcntgts_have_extent: - // Need to compute center to compare excess with stick_out_factor. - coord_t next_level_box_center_${ax} = - global_min_${ax} - + global_extent_${ax} - * (${ax}_bits + one_half) - * next_level_box_size_factor; + // Need to compute center to compare excess with stick_out_factor. + // Unused if no stickout, relying on compiler to eliminate this. + const coord_t next_level_box_center_${ax} = + global_min_${ax} + + global_extent_${ax} + * (${ax}_bits + one_half) + * next_level_box_size_factor; - coord_t next_level_box_stick_out_radius_${ax} = + %endfor + + %if srcntgts_extent_norm == "linf": + %for ax in axis_names: + const coord_t next_level_box_stick_out_radius_${ax} = box_radius_factor * global_extent_${ax} * next_level_box_size_factor; + // stop descent here if particle sticks out of next-level box stop_srcntgt_descent = stop_srcntgt_descent || (srcntgt_${ax} + srcntgt_radius >= next_level_box_center_${ax} @@ -367,8 +373,41 @@ MORTON_NR_SCAN_PREAMBLE_TPL = Template(r"""//CL// (srcntgt_${ax} - srcntgt_radius < next_level_box_center_${ax} - next_level_box_stick_out_radius_${ax}); - %endif - %endfor + %endfor + + %elif srcntgts_extent_norm == "l2": + + coord_t next_level_box_stick_out_radius = + box_radius_factor + * global_extent_x /* assume isotropy */ + * next_level_box_size_factor; + + coord_t next_level_box_center_to_srcntgt_bdry_l2_dist = + sqrt( + %for ax in axis_names: + + (srcntgt_${ax} - next_level_box_center_${ax}) + * (srcntgt_${ax} - next_level_box_center_${ax}) + %endfor + ) + srcntgt_radius; + + // stop descent here if particle sticks out of next-level box + stop_srcntgt_descent = stop_srcntgt_descent || + ( + next_level_box_center_to_srcntgt_bdry_l2_dist 
+ * next_level_box_center_to_srcntgt_bdry_l2_dist + >= ${dimensions} + * next_level_box_stick_out_radius + * next_level_box_stick_out_radius); + + %elif srcntgts_extent_norm is None: + // nothing to do + + %else: + <% + raise ValueError("unexpected value of 'srcntgts_extent_norm': %s" + % srcntgts_extent_norm) + %> + %endif // Pick off the lowest-order bit for each axis, put it in its place. int level_morton_number = 0 @@ -1244,8 +1283,11 @@ BOX_INFO_KERNEL_TPL = ElementwiseTemplate( def get_tree_build_kernel_info(context, dimensions, coord_dtype, particle_id_dtype, box_id_dtype, - sources_are_targets, srcntgts_have_extent, + sources_are_targets, srcntgts_extent_norm, morton_nr_dtype, box_level_dtype, kind): + """ + :arg srcntgts_extent_norm: one of ``None``, ``"l2"`` or ``"linf"`` + """ level_restrict = (kind == "adaptive-level-restricted") adaptive = not (kind == "non-adaptive") @@ -1269,7 +1311,7 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, dev = context.devices[0] morton_bin_count_dtype, _ = make_morton_bin_count_type( dev, dimensions, particle_id_dtype, - srcntgts_have_extent) + srcntgts_have_extent=srcntgts_extent_norm is not None) from boxtree.bounding_box import make_bounding_box_dtype bbox_dtype, bbox_type_decl = make_bounding_box_dtype( @@ -1301,7 +1343,8 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, level_restrict=level_restrict, sources_are_targets=sources_are_targets, - srcntgts_have_extent=srcntgts_have_extent, + srcntgts_have_extent=srcntgts_extent_norm is not None, + srcntgts_extent_norm=srcntgts_extent_norm, enable_assert=False, enable_printf=False, @@ -1379,12 +1422,12 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, + [VectorArg(coord_dtype, ax) for ax in axis_names] + ([VectorArg(coord_dtype, "srcntgt_radii")] - if srcntgts_have_extent else []) + if srcntgts_extent_norm is not None else []) ) morton_count_scan_arguments = list(common_arguments) - if srcntgts_have_extent: + if 
srcntgts_extent_norm is not None: morton_count_scan_arguments += [ (ScalarArg(coord_dtype, "stick_out_factor")) ] @@ -1402,7 +1445,7 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, ] + ["%s" % ax for ax in axis_names] + (["srcntgt_radii, stick_out_factor"] - if srcntgts_have_extent else []))), + if srcntgts_extent_norm is not None else []))), scan_expr="scan_t_add(a, b, across_seg_boundary)", neutral="scan_t_neutral()", is_segment_start_expr="box_start_flags[i]", @@ -1428,7 +1471,8 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, ), var_values=( ("dimensions", dimensions), - ("srcntgts_have_extent", srcntgts_have_extent), + ("srcntgts_have_extent", srcntgts_extent_norm is not None), + ("srcntgts_extent_norm", srcntgts_extent_norm), ("adaptive", adaptive), ("padded_bin", padded_bin), ("level_restrict", level_restrict), @@ -1520,7 +1564,7 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, # END KERNELS IN LEVEL LOOP - if srcntgts_have_extent: + if srcntgts_extent_norm is not None: extract_nonchild_srcntgt_count_kernel = \ EXTRACT_NONCHILD_SRCNTGT_COUNT_TPL.build( context, @@ -1636,7 +1680,7 @@ def get_tree_build_kernel_info(context, dimensions, coord_dtype, ("box_id_t", box_id_dtype), ), var_values=( - ("srcntgts_have_extent", srcntgts_have_extent), + ("srcntgts_have_extent", srcntgts_extent_norm is not None), ("sources_are_targets", sources_are_targets), ), more_preamble=generic_preamble) diff --git a/doc/images/linf-l2.png b/doc/images/linf-l2.png new file mode 100644 index 0000000000000000000000000000000000000000..0a24f867793b1aca0ca7808a80f8459be57c5dab Binary files /dev/null and b/doc/images/linf-l2.png differ diff --git a/doc/linf-l2.tikz b/doc/linf-l2.tikz new file mode 100644 index 0000000000000000000000000000000000000000..6bc330e72978325a73444032b7af6f2cfd052bff --- /dev/null +++ b/doc/linf-l2.tikz @@ -0,0 +1,41 @@ +\def\sout{0.25} +\def\pr{0.55} + +\begin{tikzpicture}[scale=1.5,baseline={(0,0)}] + 
\def\sout{0.25} + + \draw (-1,-1) rectangle (1,1); + \draw [dashed] (-1-\sout,-1-\sout) rectangle (1+\sout,1+\sout); + + \node [anchor=south] at (0,-1) {Box}; + \node [anchor=north] at (0,-1-\sout) {`Stick-out' region }; + \node [anchor=north] at (0,-1.85) {\texttt{'linf'}}; + \draw [|<->|] (0,0) -- (1+\sout,0) node [pos=0.5, anchor=north] {$(1+\alpha)R$}; + + \coordinate (pc) at (-0.25, 0.9); + \fill [red] (pc) circle (1pt); + \draw [red] (pc) ++(-\pr,-\pr) rectangle ++(2*\pr,2*\pr); + \draw [red,|<->|] (pc) -- ++(0,-\pr) node [pos=0.5,anchor=east] {$r$}; + + \node [anchor=north] at ($ (pc) + (0,-\pr)$) {Particle not in box}; + +\end{tikzpicture} +\begin{tikzpicture}[scale=1.5,baseline={(0,0)}] + \def\sqrttwo{1.4142} + + \draw (-1,-1) rectangle (1,1); + \draw [dashed] (0,0) circle (\sqrttwo+\sout*\sqrttwo); + \draw [|<->|] (0,0) -- (-1-\sout,-1-\sout) node [pos=0.3, anchor=west] {$\sqrt2(1+\alpha)R$}; + + \node [anchor=south] at (0,-1) {Box}; + \node [anchor=north] at (0,-1-\sout) {`Stick-out' region }; + \node [anchor=north] at (0,-1.85) {\texttt{'l2'}}; + + \coordinate (pc) at (-0.25, 0.9); + \fill [green] (pc) circle (1pt); + \draw [green] (pc) circle (\pr); + \draw [green,|<->|] (pc) -- ++(0,-\pr) node [pos=0.5,anchor=east] {$r$}; + + \node [anchor=north] at ($ (pc) + (0,-\pr)$) {Particle in box}; + +\end{tikzpicture} diff --git a/doc/make-images.sh b/doc/make-images.sh new file mode 100755 index 0000000000000000000000000000000000000000..282480580e48f4daed4999da4b8550d0e6e9b7e1 --- /dev/null +++ b/doc/make-images.sh @@ -0,0 +1,3 @@ +#! /bin/bash + +tikz2png linf-l2.tikz images/linf-l2.png -density 200 diff --git a/doc/tikz2png b/doc/tikz2png new file mode 100755 index 0000000000000000000000000000000000000000..022897e76ada0f0275857f6b02edc2af36e15782 --- /dev/null +++ b/doc/tikz2png @@ -0,0 +1,57 @@ +#! 
/bin/bash + +set -e +set -x + +if test "$1" = "" || test "$2" = ""; then + echo "Usage: $0 pic.tikz pic.png [CONVERT FLAGS ...]" + exit 1 +fi + +TIKZ="$1" +OUTF="$2" +shift +shift + +TMPDIR=$(mktemp -d) +TEX="$TMPDIR/source.tex" + +cat <<END > "$TEX" +\documentclass[preview]{standalone} +\nonstopmode +\usepackage{tikz} +\usetikzlibrary{calc} +\usetikzlibrary{positioning} +\usetikzlibrary{fadings} +\usetikzlibrary{chains} +\usetikzlibrary{scopes} +\usetikzlibrary{shadows} +\usetikzlibrary{arrows} +\usetikzlibrary{snakes} +\usetikzlibrary{shapes.misc} +\usetikzlibrary{shapes.symbols} +\usetikzlibrary{shapes.multipart} +\usetikzlibrary{fit} +\usetikzlibrary{shapes.arrows} +\usetikzlibrary{shapes.geometric} +\usetikzlibrary{shapes.callouts} +\usetikzlibrary{decorations.text} + +\pgfdeclarelayer{background} +\pgfdeclarelayer{foreground} +\pgfsetlayers{background,main,foreground} + +\renewcommand*\familydefault{\sfdefault} + +\begin{document} +END + +cat "$TIKZ" >> "$TEX" + +cat <<END >> "$TEX" +\end{document} +END + +(cd "$TMPDIR"; pdflatex source) +(cd "$TMPDIR"; convert "$@" source.pdf source.png) +cp "$TMPDIR/source.png" "$OUTF" diff --git a/test/test_fmm.py b/test/test_fmm.py index a32fda8c2971b2b8eb4fb659eae354ae0cdea9a2..cb5e502818e184e330eb4dcd0b3e8ea731ed7e3e 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -241,38 +241,31 @@ class ConstantOneExpansionWranglerWithFilteredTargetsInUserOrder( @pytest.mark.parametrize("well_sep_is_n_away", [1, 2]) @pytest.mark.parametrize(("dims", "nsources_req", "ntargets_req", - "who_has_extent", "source_gen", "target_gen", "filter_kind"), + "who_has_extent", "source_gen", "target_gen", "filter_kind", + "extent_norm", "from_sep_smaller_crit"), [ - (2, 10**5, None, "", p_normal, p_normal, None), - (3, 5 * 10**4, 4*10**4, "", p_normal, p_normal, None), - #(2, 5 * 10**5, 4*10**4, "s", p_normal, p_normal, None), - #(2, 5 * 10**5, 4*10**4, "st", p_normal, p_normal, None), - (2, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, None), - #(2, 5 * 10**5, 
4*10**4, "st", p_surface, p_uniform, None), - - (3, 10**5, None, "", p_normal, p_normal, None), - (3, 5 * 10**4, 4*10**4, "", p_normal, p_normal, None), - #(3, 5 * 10**5, 4*10**4, "s", p_normal, p_normal, None), - #(3, 5 * 10**5, 4*10**4, "st", p_normal, p_normal, None), - (3, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, None), - #(3, 5 * 10**5, 4*10**4, "st", p_surface, p_uniform, None), - - (2, 10**5, None, "", p_normal, p_normal, "user"), - (3, 5 * 10**4, 4*10**4, "", p_normal, p_normal, "user"), - #(2, 5 * 10**5, 4*10**4, "s", p_normal, p_normal, "user"), - #(2, 5 * 10**5, 4*10**4, "st", p_normal, p_normal, "user"), - (2, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, "user"), - #(2, 5 * 10**5, 4*10**4, "st", p_surface, p_uniform, "user"), - - (2, 10**5, None, "", p_normal, p_normal, "tree"), - (3, 5 * 10**4, 4*10**4, "", p_normal, p_normal, "tree"), - #(2, 5 * 10**5, 4*10**4, "s", p_normal, p_normal, "tree"), - #(2, 5 * 10**5, 4*10**4, "st", p_normal, p_normal, "tree"), - (2, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, "tree"), - #(2, 5 * 10**5, 4*10**4, "st", p_surface, p_uniform, "tree"), + (2, 10**5, None, "", p_normal, p_normal, None, "linf", "static_linf"), + (2, 5 * 10**4, 4*10**4, "", p_normal, p_normal, None, "linf", "static_linf"), # noqa: E501 + (2, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, None, "linf", "static_linf"), # noqa: E501 + + (3, 10**5, None, "", p_normal, p_normal, None, "linf", "static_linf"), + (3, 5 * 10**5, 4*10**4, "", p_normal, p_normal, None, "linf", "static_linf"), # noqa: E501 + (3, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, None, "linf", "static_linf"), # noqa: E501 + + (2, 10**5, None, "", p_normal, p_normal, "user", "linf", "static_linf"), + (3, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, "user", "linf", "static_linf"), # noqa: E501 + (2, 10**5, None, "", p_normal, p_normal, "tree", "linf", "static_linf"), + (3, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, "tree", "linf", "static_linf"), # noqa: E501 + + (3, 5 * 10**5, 
4*10**4, "t", p_normal, p_normal, None, "linf", "static_linf"), # noqa: E501 + (3, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, None, "linf", "precise_linf"), # noqa: E501 + (3, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, None, "l2", "precise_linf"), # noqa: E501 + (3, 5 * 10**5, 4*10**4, "t", p_normal, p_normal, None, "l2", "static_l2"), # noqa: E501 + ]) def test_fmm_completeness(ctx_getter, dims, nsources_req, ntargets_req, - who_has_extent, source_gen, target_gen, filter_kind, well_sep_is_n_away): + who_has_extent, source_gen, target_gen, filter_kind, well_sep_is_n_away, + extent_norm, from_sep_smaller_crit): """Tests whether the built FMM traversal structures and driver completely capture all interactions. """ @@ -322,14 +315,16 @@ def test_fmm_completeness(ctx_getter, dims, nsources_req, ntargets_req, tree, _ = tb(queue, sources, targets=targets, max_particles_in_box=30, source_radii=source_radii, target_radii=target_radii, - debug=True, stick_out_factor=0.25) + debug=True, stick_out_factor=0.25, extent_norm=extent_norm) if 0: tree.get().plot() import matplotlib.pyplot as pt pt.show() from boxtree.traversal import FMMTraversalBuilder - tbuild = FMMTraversalBuilder(ctx, well_sep_is_n_away=well_sep_is_n_away) + tbuild = FMMTraversalBuilder(ctx, + well_sep_is_n_away=well_sep_is_n_away, + from_sep_smaller_crit=from_sep_smaller_crit) trav, _ = tbuild(queue, tree, debug=True) if who_has_extent: diff --git a/test/test_traversal.py b/test/test_traversal.py index 5f1513135c6a5fd135976b1fb495d1fb415fe991..6c4096e3b4dfde84617ce29324dff04a5c73d6a1 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -249,6 +249,30 @@ def test_tree_connectivity(ctx_getter, dims, sources_are_targets): # }}} + # {{{ box extents make sense + + for ibox in range(tree.nboxes): + ext_low, ext_high = tree.get_box_extent(ibox) + center = tree.box_centers[:, ibox] + + for which, bbox_min, bbox_max in [ + ( + "source", + trav.box_source_bounding_box_min[:, ibox], + 
trav.box_source_bounding_box_max[:, ibox]), + ( + "target", + trav.box_target_bounding_box_min[:, ibox], + trav.box_target_bounding_box_max[:, ibox]), + ]: + assert (ext_low <= bbox_min).all() + assert (bbox_min <= center).all() + + assert (bbox_max <= ext_high).all() + assert (center <= bbox_max).all() + + # }}} + # }}} diff --git a/test/test_tree.py b/test/test_tree.py index dff309f827ecf4647879e47711f9292a80650a96..f68a885edf93ff5828e37ca4d1e4af81c866f91b 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -444,7 +444,8 @@ def test_source_target_tree(ctx_getter, dims, do_plot=False): @pytest.mark.opencl @pytest.mark.parametrize("dims", [2, 3]) -def test_extent_tree(ctx_getter, dims, do_plot=False): +@pytest.mark.parametrize("extent_norm", ["linf", "l2"]) +def test_extent_tree(ctx_getter, dims, extent_norm, do_plot=False): logging.basicConfig(level=logging.INFO) ctx = ctx_getter() @@ -477,6 +478,7 @@ def test_extent_tree(ctx_getter, dims, do_plot=False): dev_tree, _ = tb(queue, sources, targets=targets, source_radii=source_radii, target_radii=target_radii, + extent_norm=extent_norm, refine_weights=refine_weights, max_leaf_refine_weight=20, @@ -555,9 +557,6 @@ def test_extent_tree(ctx_getter, dims, do_plot=False): for ibox in range(tree.nboxes): extent_low, extent_high = tree.get_box_extent(ibox) - box_radius = np.max(extent_high-extent_low) * 0.5 - stick_out_dist = tree.stick_out_factor * box_radius - assert (extent_low >= tree.bounding_box[0] - 1e-12*tree.root_extent).all(), ibox assert (extent_high <= @@ -573,6 +572,19 @@ def test_extent_tree(ctx_getter, dims, do_plot=False): + np.sum(tree.box_target_counts_cumul[existing_children]) == tree.box_target_counts_cumul[ibox]) + del existing_children + del box_children + + for ibox in range(tree.nboxes): + lev = int(tree.box_levels[ibox]) + box_radius = 0.5 * tree.root_extent / (1 << lev) + box_center = tree.box_centers[:, ibox] + extent_low = box_center - box_radius + extent_high = box_center + box_radius + + 
stick_out_dist = tree.stick_out_factor * box_radius + radius_with_stickout = (1 + tree.stick_out_factor) * box_radius + for what, starts, counts, points, radii in [ ("source", tree.box_source_starts, tree.box_source_counts_cumul, sorted_sources, sorted_source_radii), @@ -584,18 +596,33 @@ def test_extent_tree(ctx_getter, dims, do_plot=False): check_particles = points[:, bslice] check_radii = radii[bslice] - good = ( - (check_particles + check_radii - < extent_high[:, np.newaxis] + stick_out_dist) - & - (extent_low[:, np.newaxis] - stick_out_dist - <= check_particles - check_radii) - ).all(axis=0) + if extent_norm == "linf": + good = ( + (check_particles + check_radii + < extent_high[:, np.newaxis] + stick_out_dist) + & + (extent_low[:, np.newaxis] - stick_out_dist + <= check_particles - check_radii) + ).all(axis=0) + + elif extent_norm == "l2": + center_dists = np.sqrt( + np.sum( + (check_particles - box_center.reshape(-1, 1))**2, + axis=0)) + + good = ( + (center_dists + check_radii)**2 + < dims * radius_with_stickout**2) + + else: + raise ValueError("unexpected value of extent_norm") all_good_here = good.all() if not all_good_here: - print("BAD BOX %s %d level %d" % (what, ibox, tree.box_levels[ibox])) + print("BAD BOX %s %d level %d" + % (what, ibox, tree.box_levels[ibox])) all_good_so_far = all_good_so_far and all_good_here assert all_good_here