diff --git a/boxtree/area_query.py b/boxtree/area_query.py index 88cc4b625b233c9c5b5fcc7517929f3ee38e242b..a5ebdca20f178531bbe0fab095f322ceaf0941f6 100644 --- a/boxtree/area_query.py +++ b/boxtree/area_query.py @@ -835,16 +835,24 @@ class LeavesToBallsLookupBuilder(object): logger.info("leaves-to-balls lookup: expand starts") - nkeys = len(area_query.leaves_near_ball_lists) + nkeys = tree.nboxes nballs_p_1 = len(area_query.leaves_near_ball_starts) assert nballs_p_1 == len(ball_radii) + 1 + # We invert the area query in two steps: + # + # 1. Turn the area query result into (ball number, box number) pairs. + # This is done in the "starts expander kernel." + # + # 2. Key-value sort the (ball number, box number) pairs by box number. + starts_expander_knl = self.get_starts_expander_kernel(tree.box_id_dtype) - expanded_starts = cl.array.empty(queue, nkeys, tree.box_id_dtype) + expanded_starts = cl.array.empty( + queue, len(area_query.leaves_near_ball_lists), tree.box_id_dtype) evt = starts_expander_knl( - expanded_starts, - area_query.leaves_near_ball_starts.with_queue(queue), - nballs_p_1) + expanded_starts, + area_query.leaves_near_ball_starts.with_queue(queue), + nballs_p_1) wait_for = [evt] logger.info("leaves-to-balls lookup: key-value sort") diff --git a/test/test_tree.py b/test/test_tree.py index 9fc08334b5a2f35bb2b9f66893accfb0a60374d8..0857787724a13cc43535d41cba7b713c0c6c16f5 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -636,6 +636,8 @@ def test_leaves_to_balls_query(ctx_getter, dims, do_plot=False): ball_centers = np.array([x.get() for x in ball_centers]).T ball_radii = ball_radii.get() + assert len(lbl.balls_near_box_starts) == tree.nboxes + 1 + from boxtree import box_flags_enum for ibox in range(tree.nboxes):