diff --git a/boxtree/tree_build.py b/boxtree/tree_build.py index 1647bfca9fc9da6d16d1cca92651c454b2e9c0b1..9670487d91b588d9ee5fb8081d8a9e891bd3ac01 100644 --- a/boxtree/tree_build.py +++ b/boxtree/tree_build.py @@ -83,7 +83,7 @@ class TreeBuilder(object): targets=None, source_radii=None, target_radii=None, stick_out_factor=None, refine_weights=None, max_leaf_refine_weight=None, wait_for=None, - extent_norm=None, + extent_norm=None, bbox=None, **kwargs): """ :arg queue: a :class:`pyopencl.CommandQueue` instance @@ -123,6 +123,17 @@ class TreeBuilder(object): execution. :arg extent_norm: ``"l2"`` or ``"linf"``. Indicates the norm with respect to which particle stick-out is measured. See :attr:`Tree.extent_norm`. + :arg bbox: Bounding box of either type: + 1. A dim-by-2 array, with each row to be [min, max] coordinates + in its corresponding axis direction. + 2. (Internal use only) of the same type as returned by + *boxtree.bounding_box.make_bounding_box_dtype*. + When given, this bounding box is used for tree + building. Otherwise, the bounding box is determined from particles + in such a way that it is square and is slightly larger at the top (so + that scaled coordinates are always < 1). + When supplied, the bounding box must be square and have all the + particles in its closure. :arg kwargs: Used internally for debugging. :returns: a tuple ``(tree, event)``, where *tree* is an instance of @@ -342,22 +353,60 @@ class TreeBuilder(object): # {{{ find and process bounding box - bbox, _ = self.bbox_finder(srcntgts, srcntgt_radii, wait_for=wait_for) - bbox = bbox.get() + if bbox is None: + bbox, _ = self.bbox_finder(srcntgts, srcntgt_radii, wait_for=wait_for) + bbox = bbox.get() - root_extent = max( + root_extent = max( bbox["max_"+ax] - bbox["min_"+ax] for ax in axis_names) * (1+TreeBuilder.ROOT_EXTENT_STRETCH_FACTOR) - # make bbox square and slightly larger at the top, to ensure scaled - # coordinates are always < 1 - bbox_min = np.empty(dimensions, coord_dtype) - for i, ax in enumerate(axis_names): - bbox_min[i] = bbox["min_"+ax] + # make bbox square and slightly larger at the top, to ensure scaled + # coordinates are always < 1 + bbox_min = np.empty(dimensions, coord_dtype) + for i, ax in enumerate(axis_names): + bbox_min[i] = bbox["min_"+ax] - bbox_max = bbox_min + root_extent - for i, ax in enumerate(axis_names): - bbox["max_"+ax] = bbox_max[i] + bbox_max = bbox_min + root_extent + for i, ax in enumerate(axis_names): + bbox["max_"+ax] = bbox_max[i] + else: + # Validate that bbox is a superset of particle-derived bbox + bbox_auto, _ = self.bbox_finder( + srcntgts, srcntgt_radii, wait_for=wait_for) + bbox_auto = bbox_auto.get() + + # Convert unstructured numpy array to bbox_type + if isinstance(bbox, np.ndarray): + if len(bbox) == dimensions: + bbox_bak = bbox.copy() + bbox = np.empty(1, bbox_auto.dtype) + for i, ax in enumerate(axis_names): + bbox['min_'+ax] = bbox_bak[i][0] + bbox['max_'+ax] = bbox_bak[i][1] + else: + assert len(bbox) == 1 + else: + raise NotImplementedError("Unsupported bounding box type: " + + str(type(bbox))) + + # bbox must cover bbox_auto + bbox_min = np.empty(dimensions, coord_dtype) + bbox_max = np.empty(dimensions, coord_dtype) + + for i, ax in enumerate(axis_names): + bbox_min[i] = bbox["min_" + ax] + bbox_max[i] = bbox["max_" + ax] + assert bbox_min[i] < bbox_max[i] + assert bbox_min[i] <= bbox_auto["min_" + ax] + assert bbox_max[i] >= bbox_auto["max_" + ax] + + # bbox must be a square + bbox_exts = bbox_max - bbox_min + for ext in bbox_exts: + assert abs(ext - bbox_exts[0]) < 1e-15 + + root_extent = bbox_exts[0] # }}}