diff --git a/boxtree/area_query.py b/boxtree/area_query.py index 2742a26f1601c6083a5beef659f79df9a3bfc75d..3e7742d1f19205ad181e9fca1cf69f9d2589cf5c 100644 --- a/boxtree/area_query.py +++ b/boxtree/area_query.py @@ -132,107 +132,94 @@ class LeavesToBallsLookup(DeviceDataRecord): # }}} -# {{{ kernel templates - -AREA_QUERY_TEMPLATE = r"""//CL// -typedef ${dtype_to_ctype(ball_id_dtype)} ball_id_t; -typedef ${dtype_to_ctype(peer_list_idx_dtype)} peer_list_idx_t; -<%def name="add_box_to_list_if_overlaps_ball(box_id)"> - { - bool is_overlapping; +# {{{ kernel templates - ${check_l_infty_ball_overlap( - "is_overlapping", box_id, "ball_radius", "ball_center")} +GUIDING_BOX_FINDER_MACRO = r"""//CL:mako// + <%def name="find_guiding_box(ball_center, ball_radius)"> + ${walk_init(0)} + box_id_t guiding_box; - if (is_overlapping) + if (LEVEL_TO_RAD(0) < ${ball_radius} / 2 + || !(box_flags[0] & BOX_HAS_CHILDREN)) { - APPEND_leaves(${box_id}); + guiding_box = 0; + continue_walk = false; } - } - - -void generate(LIST_ARG_DECL USER_ARG_DECL ball_id_t ball_nr) -{ - coord_vec_t ball_center; - %for i in range(dimensions): - ball_center.${AXIS_NAMES[i]} = ball_${AXIS_NAMES[i]}[ball_nr]; - %endfor - - coord_t ball_radius = ball_radii[ball_nr]; - - /////////////////////////////////// - // Step 1: Find the guiding box. // - /////////////////////////////////// - ${walk_init(0)} - box_id_t guiding_box; + while (continue_walk) + { + // Get the next child. + box_id_t child_box_id = box_child_ids[ + walk_morton_nr * aligned_nboxes + walk_box_id]; - if (LEVEL_TO_RAD(0) < ball_radius / 2 || !(box_flags[0] & BOX_HAS_CHILDREN)) - { - guiding_box = 0; - continue_walk = false; - } + bool last_child = walk_morton_nr == ${2**dimensions - 1}; - while (continue_walk) - { - // Get the next child. 
- box_id_t child_box_id = box_child_ids[ - walk_morton_nr * aligned_nboxes + walk_box_id]; + if (child_box_id) + { + bool contains_ball_center; + int child_level = walk_level + 1; + coord_t child_rad = LEVEL_TO_RAD(child_level); - bool last_child = walk_morton_nr == ${2**dimensions - 1}; + { + // Check if the child contains the ball's center. + ${load_center("child_center", "child_box_id")} - if (child_box_id) - { - bool contains_ball_center; - int child_level = walk_level + 1; - coord_t child_rad = LEVEL_TO_RAD(child_level); + coord_t max_dist = 0; + %for i in range(dimensions): + max_dist = fmax(max_dist, + distance(${ball_center}.s${i}, child_center.s${i})); + %endfor - { - // Check if the child contains the ball's center. - ${load_center("child_center", "child_box_id")} + contains_ball_center = max_dist <= child_rad; + } - coord_t max_dist = 0; - %for i in range(dimensions): - max_dist = fmax(max_dist, - fabs(ball_center.s${i} - child_center.s${i})); - %endfor + if (contains_ball_center) + { + if ((child_rad / 2 < ${ball_radius} + && ${ball_radius} <= child_rad) || + !(box_flags[child_box_id] & BOX_HAS_CHILDREN)) + { + guiding_box = child_box_id; + break; + } - contains_ball_center = max_dist <= child_rad; + // We want to descend into this box. Put the current state + // on the stack. + ${walk_push("child_box_id")} + continue; + } } - if (contains_ball_center) + if (last_child) { - if ((child_rad / 2 < ball_radius && ball_radius <= child_rad) || - !(box_flags[child_box_id] & BOX_HAS_CHILDREN)) - { - guiding_box = child_box_id; - break; - } - - // We want to descend into this box. Put the current state - // on the stack. - ${walk_push("child_box_id")} - continue; + // This box has no children that contain the center, so it must + // be the guiding box. + guiding_box = walk_box_id; + break; } - } - if (last_child) - { - // This box has no children that contain the center, so it must - // be the guiding box. 
- guiding_box = walk_box_id; - break; + ${walk_advance()} } + +""" - ${walk_advance()} - } + +AREA_QUERY_WALKER_BODY = r""" + coord_vec_t ball_center; + coord_t ball_radius; + ${get_ball_center_and_radius("ball_center", "ball_radius", "i")} + + /////////////////////////////////// + // Step 1: Find the guiding box. // + /////////////////////////////////// + + ${find_guiding_box("ball_center", "ball_radius")} ////////////////////////////////////////////////////// // Step 2 - Walk the peer boxes to find the leaves. // ////////////////////////////////////////////////////// - for (peer_list_idx_t pb_i = peer_list_starts[guiding_box], pb_e = peer_list_starts[guiding_box+1]; pb_i < pb_e; ++pb_i) { @@ -240,7 +227,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL ball_id_t ball_nr) if (!(box_flags[peer_box] & BOX_HAS_CHILDREN)) { - ${add_box_to_list_if_overlaps_ball("peer_box")} + ${leaf_found_op("peer_box", "ball_center", "ball_radius")} } else { @@ -255,7 +242,8 @@ void generate(LIST_ARG_DECL USER_ARG_DECL ball_id_t ball_nr) { if (!(box_flags[child_box_id] & BOX_HAS_CHILDREN)) { - ${add_box_to_list_if_overlaps_ball("child_box_id")} + ${leaf_found_op("child_box_id", "ball_center", + "ball_radius")} } else { @@ -270,9 +258,44 @@ void generate(LIST_ARG_DECL USER_ARG_DECL ball_id_t ball_nr) } } } -} """ + +AREA_QUERY_TEMPLATE = ( + GUIDING_BOX_FINDER_MACRO + r"""//CL// + typedef ${dtype_to_ctype(ball_id_dtype)} ball_id_t; + typedef ${dtype_to_ctype(peer_list_idx_dtype)} peer_list_idx_t; + + <%def name="get_ball_center_and_radius(ball_center, ball_radius, i)"> + %for ax in AXIS_NAMES[:dimensions]: + ${ball_center}.${ax} = ball_${ax}[${i}]; + %endfor + ${ball_radius} = ball_radii[${i}]; + + + <%def name="leaf_found_op(leaf_box_id, ball_center, ball_radius)"> + { + bool is_overlapping; + + ${check_l_infty_ball_overlap( + "is_overlapping", leaf_box_id, ball_radius, ball_center)} + + if (is_overlapping) + { + APPEND_leaves(${leaf_box_id}); + } + } + + + void generate(LIST_ARG_DECL 
USER_ARG_DECL ball_id_t i) + { + """ + + AREA_QUERY_WALKER_BODY + + """ + } + """) + + PEER_LIST_FINDER_TEMPLATE = r"""//CL// void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t box_id) @@ -360,6 +383,7 @@ void generate(LIST_ARG_DECL USER_ARG_DECL box_id_t box_id) from pyopencl.elementwise import ElementwiseTemplate +from boxtree.tools import InlineBinarySearch STARTS_EXPANDER_TEMPLATE = ElementwiseTemplate( @@ -370,29 +394,115 @@ STARTS_EXPANDER_TEMPLATE = ElementwiseTemplate( """, operation=r"""//CL// /* Find my index in starts, place the index in dst. */ - idx_t l_idx = 0, r_idx = starts_len - 1, my_idx; + dst[i] = bsearch(starts, starts_len, i); + """, + name="starts_expander", + preamble=str(InlineBinarySearch("idx_t"))) - for (;;) - { - my_idx = (l_idx + r_idx) / 2; +# }}} - if (starts[my_idx] <= i && i < starts[my_idx + 1]) - { - dst[i] = my_idx; - break; - } - if (starts[my_idx] > i) - { - r_idx = my_idx - 1; - } - else - { - l_idx = my_idx + 1; - } - } - """, - name="starts_expander") +# {{{ area query elementwise template + +class AreaQueryElementwiseTemplate(object): + """ + Experimental: Intended as a way to perform operations in the body of an area + query. 
+ """ + + @staticmethod + def unwrap_args(tree, peer_lists, *args): + return (tree.box_centers, + tree.root_extent, + tree.box_levels, + tree.aligned_nboxes, + tree.box_child_ids, + tree.box_flags, + peer_lists.peer_list_starts, + peer_lists.peer_lists) + args + + def __init__(self, extra_args, ball_center_and_radius_expr, + leaf_found_op, preamble="", name="area_query_elwise"): + + def wrap_in_macro(decl, expr): + return """ + <%def name=\"{decl}\"> + {expr} + + """.format(decl=decl, expr=expr) + + from boxtree.traversal import TRAVERSAL_PREAMBLE_MAKO_DEFS + + self.elwise_template = ElementwiseTemplate( + arguments=r"""//CL:mako// + coord_t *box_centers, + coord_t root_extent, + box_level_t *box_levels, + box_id_t aligned_nboxes, + box_id_t *box_child_ids, + box_flags_t *box_flags, + peer_list_idx_t *peer_list_starts, + box_id_t *peer_lists, + """ + extra_args, + operation="//CL:mako//\n" + + wrap_in_macro("get_ball_center_and_radius(ball_center, ball_radius, i)", + ball_center_and_radius_expr) + + wrap_in_macro("leaf_found_op(leaf_box_id, ball_center, ball_radius)", + leaf_found_op) + + TRAVERSAL_PREAMBLE_MAKO_DEFS + + GUIDING_BOX_FINDER_MACRO + + AREA_QUERY_WALKER_BODY, + name=name, + preamble=preamble) + + def generate(self, context, + dimensions, coord_dtype, box_id_dtype, + peer_list_idx_dtype, max_levels, + extra_var_values=(), extra_type_aliases=(), + extra_preamble=""): + from pyopencl.tools import dtype_to_ctype + from boxtree import box_flags_enum + from boxtree.traversal import TRAVERSAL_PREAMBLE_TYPEDEFS_AND_DEFINES + + render_vars = ( + ("dimensions", dimensions), + ("dtype_to_ctype", dtype_to_ctype), + ("box_id_dtype", box_id_dtype), + ("particle_id_dtype", None), + ("coord_dtype", coord_dtype), + ("vec_types", tuple(cl.array.vec.types.items())), + ("max_levels", max_levels), + ("AXIS_NAMES", AXIS_NAMES), + ("box_flags_enum", box_flags_enum), + ("peer_list_idx_dtype", peer_list_idx_dtype), + ("debug", False), + # Not used (but required by 
TRAVERSAL_PREAMBLE_TEMPLATE) + ("stick_out_factor", 0), + ) + + preamble = Template( + # HACK: box_flags_t and coord_t are defined here and + # in the template below, so disable typedef redefinition warnings. + """ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wtypedef-redefinition" + """ + + TRAVERSAL_PREAMBLE_TYPEDEFS_AND_DEFINES + + """ + #pragma clang diagnostic pop + """, + strict_undefined=True).render(**dict(render_vars)) + + return self.elwise_template.build(context, + type_aliases=( + ("coord_t", coord_dtype), + ("box_id_t", box_id_dtype), + ("peer_list_idx_t", peer_list_idx_dtype), + ("box_level_t", np.uint8), + ("box_flags_t", box_flags_enum.dtype), + ) + extra_type_aliases, + var_values=render_vars + extra_var_values, + more_preamble=preamble + extra_preamble) # }}} @@ -540,9 +650,9 @@ class AreaQueryBuilder(object): leaves_near_ball_starts=result["leaves"].starts, leaves_near_ball_lists=result["leaves"].lists).with_queue(None), evt - # }}} + # {{{ area query transpose (leaves-to-balls) lookup build class LeavesToBallsLookupBuilder(object): @@ -641,9 +751,9 @@ class LeavesToBallsLookupBuilder(object): balls_near_box_starts=balls_near_box_starts, balls_near_box_lists=balls_near_box_lists).with_queue(None), evt - # }}} + # {{{ peer list build class PeerListFinder(object): diff --git a/boxtree/tools.py b/boxtree/tools.py index 638c5fa72d0c135d6420c2248a38ca8e0ee7db23..b048712045585f62954441a21029ea4ebf861115 100644 --- a/boxtree/tools.py +++ b/boxtree/tools.py @@ -501,4 +501,47 @@ class MapValuesKernel(object): # }}} + +# {{{ binary search + +from mako.template import Template + + +BINARY_SEARCH_TEMPLATE = Template(""" +inline size_t bsearch(__global ${idx_t} *starts, size_t len, ${idx_t} val) +{ + size_t l_idx = 0, r_idx = len - 1, my_idx; + for (;;) + { + my_idx = (l_idx + r_idx) / 2; + + if (starts[my_idx] <= val && val < starts[my_idx + 1]) + { + return my_idx; + } + + if (starts[my_idx] > val) + { + r_idx = my_idx - 1; + } + 
else + { + l_idx = my_idx + 1; + } + } +} +""") + + +class InlineBinarySearch(object): + + def __init__(self, idx_t): + self.idx_t = idx_t + + @memoize_method + def __str__(self): + return BINARY_SEARCH_TEMPLATE.render(idx_t=self.idx_t) + +# }}} + # vim: foldmethod=marker:filetype=pyopencl diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 98cd59007a8f1249a149b96af3568bac1490d897..10da4447253b4465d8cef08afc2b98e83c9d8349 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -133,8 +133,10 @@ typedef ${dtype_to_ctype(box_id_dtype)} box_id_t; %if particle_id_dtype is not None: typedef ${dtype_to_ctype(particle_id_dtype)} particle_id_t; %endif +## Convert to dict first, as this may be passed as a tuple-of-tuples. +<% vec_types_dict = dict(vec_types) %> typedef ${dtype_to_ctype(coord_dtype)} coord_t; -typedef ${dtype_to_ctype(vec_types[coord_dtype, dimensions])} coord_vec_t; +typedef ${dtype_to_ctype(vec_types_dict[coord_dtype, dimensions])} coord_vec_t; #define NLEVELS ${max_levels} #define STICK_OUT_FACTOR ((coord_t) ${stick_out_factor}) diff --git a/test/test_tree.py b/test/test_tree.py index 18210aaf282f0bc3f77b34668662d96197f8ff75..5f91f88f080762a94d9efbc2ea65746a58cb4e0c 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -308,12 +308,8 @@ def test_explicit_refine_weights_particle_tree(ctx_getter, dtype, dims, nparticles = 10**5 from pyopencl.clrandom import PhiloxGenerator - import random - random.seed(10) - rng = PhiloxGenerator(ctx) - refine_weights = cl.array.empty(queue, nparticles, np.int32) - evt = rng.fill_uniform(refine_weights, a=1, b=10) - cl.wait_for_events([evt]) + rng = PhiloxGenerator(ctx, seed=10) + refine_weights = rng.uniform(queue, nparticles, dtype=np.int32, a=1, b=10) run_build_test(builder, queue, dims, dtype, nparticles, refine_weights=refine_weights, max_leaf_refine_weight=100, @@ -725,6 +721,70 @@ def test_area_query(ctx_getter, dims, do_plot=False): actual = near_leaves assert sorted(found) == sorted(actual) + 
+@pytest.mark.opencl +@pytest.mark.area_query +@pytest.mark.parametrize("dims", [2, 3]) +def test_area_query_elwise(ctx_getter, dims, do_plot=False): + ctx = ctx_getter() + queue = cl.CommandQueue(ctx) + + nparticles = 10**5 + dtype = np.float64 + + particles = make_normal_particle_array(queue, nparticles, dims, dtype) + + if do_plot: + import matplotlib.pyplot as pt + pt.plot(particles[0].get(), particles[1].get(), "x") + + from boxtree import TreeBuilder + tb = TreeBuilder(ctx) + + queue.finish() + tree, _ = tb(queue, particles, max_particles_in_box=30, debug=True) + + nballs = 10**4 + ball_centers = make_normal_particle_array(queue, nballs, dims, dtype) + ball_radii = cl.array.empty(queue, nballs, dtype).fill(0.1) + + from boxtree.area_query import ( + AreaQueryElementwiseTemplate, PeerListFinder) + + template = AreaQueryElementwiseTemplate( + extra_args=""" + coord_t *ball_radii, + %for ax in AXIS_NAMES[:dimensions]: + coord_t *ball_${ax}, + %endfor + """, + ball_center_and_radius_expr=""" + %for ax in AXIS_NAMES[:dimensions]: + ${ball_center}.${ax} = ball_${ax}[${i}]; + %endfor + ${ball_radius} = ball_radii[${i}]; + """, + leaf_found_op="") + + peer_lists, evt = PeerListFinder(ctx)(queue, tree) + + kernel = template.generate( + ctx, + dims, + tree.coord_dtype, + tree.box_id_dtype, + peer_lists.peer_list_starts.dtype, + tree.nlevels) + + evt = kernel( + *template.unwrap_args( + tree, peer_lists, ball_radii, *ball_centers), + queue=queue, + wait_for=[evt], + range=slice(len(ball_radii))) + + cl.wait_for_events([evt]) + # }}}