diff --git a/boxtree/tree.py b/boxtree/tree.py index 213a158f50c96de8e921ade6361f2e9469597590..de52fd9c0ae960ea0c1a5983e66a6dd465b9a560 100644 --- a/boxtree/tree.py +++ b/boxtree/tree.py @@ -431,182 +431,362 @@ class TreeWithLinkedPointSources(Tree): ``particle_id_t [nboxes]`` """ - def __init__(self, queue, tree, point_source_starts, point_sources, - debug=False): - """ - *Construction:* Requires that :attr:`Tree.sources_have_extent` is *True* - on *tree*. - :arg queue: a :class:`pyopencl.CommandQueue` instance - :arg point_source_starts: ``point_source_starts[isrc]`` and - ``point_source_starts[isrc+1]`` together indicate a ranges of point - particle indices in *point_sources* which will be linked to the - original (extent-having) source number *isrc*. *isrc* is in :ref:`user - source order `. +def link_point_sources(queue, tree, point_source_starts, point_sources, + debug=False): + """ + *Construction:* Requires that :attr:`Tree.sources_have_extent` is *True* + on *tree*. + + :arg queue: a :class:`pyopencl.CommandQueue` instance + :arg point_source_starts: ``point_source_starts[isrc]`` and + ``point_source_starts[isrc+1]`` together indicate a ranges of point + particle indices in *point_sources* which will be linked to the + original (extent-having) source number *isrc*. *isrc* is in :ref:`user + source order `. - All the particles linked to *isrc* shoud fall within the :math:`l^\infty` - 'circle' around particle number *isrc* with the radius drawn from - :attr:`source_radii`. + All the particles linked to *isrc* shoud fall within the :math:`l^\infty` + 'circle' around particle number *isrc* with the radius drawn from + :attr:`source_radii`. - :arg point_sources: an object array of (XYZ) point coordinate arrays. - """ + :arg point_sources: an object array of (XYZ) point coordinate arrays. + """ + + # The whole point of this routine is that all point sources within + # a box are reordered to be contiguous. + + logger.info("point source linking: start") + + if not tree.sources_have_extent: + raise ValueError("only allowed on trees whose sources have extent") + + npoint_sources_dev = cl.array.empty(queue, (), tree.particle_id_dtype) + + # {{{ compute tree_order_point_source_{starts, counts} + + # Scan over lengths of point source lists in tree order to determine + # indices of point source starts for each source. + + tree_order_point_source_starts = cl.array.empty( + queue, tree.nsources, tree.particle_id_dtype) + tree_order_point_source_counts = cl.array.empty( + queue, tree.nsources, tree.particle_id_dtype) + + from boxtree.tree_build_kernels import POINT_SOURCE_LINKING_SOURCE_SCAN_TPL + knl = POINT_SOURCE_LINKING_SOURCE_SCAN_TPL.build( + queue.context, + type_aliases=( + ("scan_t", tree.particle_id_dtype), + ("index_t", tree.particle_id_dtype), + ("particle_id_t", tree.particle_id_dtype), + ), + ) - # The whole point of this routine is that all point sources within - # a box are reordered to be contiguous. + logger.debug("point source linking: tree order source scan") - logger.info("point source linking: start") + knl(point_source_starts, tree.user_source_ids, + tree_order_point_source_starts, tree_order_point_source_counts, + npoint_sources_dev, size=tree.nsources, queue=queue) - if not tree.sources_have_extent: - raise ValueError("only allowed on trees whose sources have extent") + # }}} + + npoint_sources = int(npoint_sources_dev.get()) - npoint_sources_dev = cl.array.empty(queue, (), tree.particle_id_dtype) + # {{{ compute user_point_source_ids - # {{{ compute tree_order_point_source_{starts, counts} + # A list of point source starts, indexed in tree order, + # but giving point source indices in user order. + tree_order_index_user_point_source_starts = cl.array.take( + point_source_starts, tree.user_source_ids, + queue=queue) - # Scan over lengths of point source lists in tree order to determine - # indices of point source starts for each source. + user_point_source_ids = cl.array.empty( + queue, npoint_sources, tree.particle_id_dtype) + user_point_source_ids.fill(1) + cl.array.multi_put([tree_order_index_user_point_source_starts], + dest_indices=tree_order_point_source_starts, + out=[user_point_source_ids]) - tree_order_point_source_starts = cl.array.empty( - queue, tree.nsources, tree.particle_id_dtype) - tree_order_point_source_counts = cl.array.empty( - queue, tree.nsources, tree.particle_id_dtype) + if debug: + ups_host = user_point_source_ids.get() + assert (ups_host >= 0).all() + assert (ups_host < npoint_sources).all() - from boxtree.tree_build_kernels import POINT_SOURCE_LINKING_SOURCE_SCAN_TPL - knl = POINT_SOURCE_LINKING_SOURCE_SCAN_TPL.build( - queue.context, - type_aliases=( - ("scan_t", tree.particle_id_dtype), - ("index_t", tree.particle_id_dtype), - ("particle_id_t", tree.particle_id_dtype), - ), - ) + source_boundaries = cl.array.zeros(queue, npoint_sources, np.int8) - logger.debug("point source linking: tree order source scan") + # FIXME: Should be a scalar, in principle. + ones = cl.array.empty(queue, tree.nsources, np.int8) + ones.fill(1) - knl(point_source_starts, tree.user_source_ids, - tree_order_point_source_starts, tree_order_point_source_counts, - npoint_sources_dev, size=tree.nsources, queue=queue) + cl.array.multi_put( + [ones], + dest_indices=tree_order_point_source_starts, + out=[source_boundaries]) - # }}} + from boxtree.tree_build_kernels import \ + POINT_SOURCE_LINKING_USER_POINT_SOURCE_ID_SCAN_TPL - npoint_sources = int(npoint_sources_dev.get()) + logger.debug("point source linking: point source id scan") - # {{{ compute user_point_source_ids + knl = POINT_SOURCE_LINKING_USER_POINT_SOURCE_ID_SCAN_TPL.build( + queue.context, + type_aliases=( + ("scan_t", tree.particle_id_dtype), + ("index_t", tree.particle_id_dtype), + ("particle_id_t", tree.particle_id_dtype), + ), + ) + knl(source_boundaries, user_point_source_ids, + size=npoint_sources, queue=queue) - # A list of point source starts, indexed in tree order, - # but giving point source indices in user order. - tree_order_index_user_point_source_starts = cl.array.take( - point_source_starts, tree.user_source_ids, - queue=queue) + if debug: + ups_host = user_point_source_ids.get() + assert (ups_host >= 0).all() + assert (ups_host < npoint_sources).all() + + # }}} - user_point_source_ids = cl.array.empty( - queue, npoint_sources, tree.particle_id_dtype) - user_point_source_ids.fill(1) - cl.array.multi_put([tree_order_index_user_point_source_starts], - dest_indices=tree_order_point_source_starts, - out=[user_point_source_ids]) + from pytools.obj_array import make_obj_array + tree_order_point_sources = make_obj_array([ + cl.array.take(point_sources[i], user_point_source_ids, + queue=queue) + for i in range(tree.dimensions) + ]) - if debug: - ups_host = user_point_source_ids.get() - assert (ups_host >= 0).all() - assert (ups_host < npoint_sources).all() + # {{{ compute box point source metadata - source_boundaries = cl.array.zeros(queue, npoint_sources, np.int8) + from boxtree.tree_build_kernels import POINT_SOURCE_LINKING_BOX_POINT_SOURCES - # FIXME: Should be a scalar, in principle. - ones = cl.array.empty(queue, tree.nsources, np.int8) - ones.fill(1) + knl = POINT_SOURCE_LINKING_BOX_POINT_SOURCES.build( + queue.context, + type_aliases=( + ("particle_id_t", tree.particle_id_dtype), + ("box_id_t", tree.box_id_dtype), + ), + ) - cl.array.multi_put( - [ones], - dest_indices=tree_order_point_source_starts, - out=[source_boundaries]) + logger.debug("point source linking: box point sources") - from boxtree.tree_build_kernels import \ - POINT_SOURCE_LINKING_USER_POINT_SOURCE_ID_SCAN_TPL + box_point_source_starts = cl.array.empty( + queue, tree.nboxes, tree.particle_id_dtype) + box_point_source_counts_nonchild = cl.array.empty( + queue, tree.nboxes, tree.particle_id_dtype) + box_point_source_counts_cumul = cl.array.empty( + queue, tree.nboxes, tree.particle_id_dtype) - logger.debug("point source linking: point source id scan") + knl( + box_point_source_starts, box_point_source_counts_nonchild, + box_point_source_counts_cumul, - knl = POINT_SOURCE_LINKING_USER_POINT_SOURCE_ID_SCAN_TPL.build( - queue.context, - type_aliases=( - ("scan_t", tree.particle_id_dtype), - ("index_t", tree.particle_id_dtype), - ("particle_id_t", tree.particle_id_dtype), - ), - ) - knl(source_boundaries, user_point_source_ids, - size=npoint_sources, queue=queue) + tree.box_source_starts, tree.box_source_counts_nonchild, + tree.box_source_counts_cumul, + + tree_order_point_source_starts, + tree_order_point_source_counts, + range=slice(tree.nboxes), queue=queue) + + # }}} - if debug: - ups_host = user_point_source_ids.get() - assert (ups_host >= 0).all() - assert (ups_host < npoint_sources).all() + logger.info("point source linking: complete") + + tree_attrs = {} + for attr_name in tree.__class__.fields: + try: + tree_attrs[attr_name] = getattr(tree, attr_name) + except AttributeError: + pass + + return TreeWithLinkedPointSources( + npoint_sources=npoint_sources, + point_source_starts=tree_order_point_source_starts, + point_source_counts=tree_order_point_source_counts, + point_sources=tree_order_point_sources, + user_point_source_ids=user_point_source_ids, + box_point_source_starts=box_point_source_starts, + box_point_source_counts_nonchild=box_point_source_counts_nonchild, + box_point_source_counts_cumul=box_point_source_counts_cumul, + + **tree_attrs).with_queue(None) + + +# }}} - # }}} - from pytools.obj_array import make_obj_array - tree_order_point_sources = make_obj_array([ - cl.array.take(point_sources[i], user_point_source_ids, - queue=queue) - for i in range(tree.dimensions) +# {{{ filtered target lists + +class FilteredTargetListsInUserOrder(DeviceDataRecord): + """This class builds subsets of the list of targets in each box (as given by + :attr:`boxtree.Tree.box_target_starts` and + :attr:`boxtree.Tree.box_target_counts_cumul`).This subset is + specified by an array of *flags* in user target order. + + The list consists of target numbers in user target order. + See also :class:`FilteredTargetListsInTreeOrder`. + + .. attribute:: nfiltered_targets + + .. attribute:: target_starts + + ``particle_id_t [nboxes+1]`` + + Filtered list of targets in each box. Records start indices in + :attr:`boxtree.Tree.targets` for each box. Use together with + :attr:`target_counts_nonchild`. The lists for each box are + contiguous, so that ``target_starts[ibox+1]`` records the + end of the target list for *ibox*. + + .. attribute:: target_lists + + ``particle_id_t [nboxes]`` + + Filtered list of targets in each box. Records number of sources from + :attr:`boxtree.Tree.targets` in each box (excluding those belonging to + child boxes). Use together with :attr:`target_starts`. + """ + + +def filter_target_lists_in_user_order(queue, tree, flags): + """ + :arg flags: an array of length :attr:`boxtree.Tree.ntargets` of + :class:`numpy.int8` objects, which indicate by being zero that the + corresponding target (in user target order) is not part of the + filtered list, or by being nonzero that it is. + + :returns: A :class:`FilteredTargetListsInUserOrder` + """ + + user_order_flags = flags + del flags + + user_target_ids = cl.array.empty(queue, tree.ntargets, + tree.sorted_target_ids.dtype) + user_target_ids[tree.sorted_target_ids] = cl.array.arange( + queue, tree.ntargets, user_target_ids.dtype) + + from pyopencl.tools import VectorArg, dtype_to_ctype + from pyopencl.algorithm import ListOfListsBuilder + from mako.template import Template + builder = ListOfListsBuilder(queue.context, + [("filt_tgt_list", tree.particle_id_dtype)], Template("""//CL// + typedef ${dtype_to_ctype(particle_id_dtype)} particle_id_t; + + void generate(LIST_ARG_DECL USER_ARG_DECL index_type i) + { + particle_id_t b_t_start = box_target_starts[i]; + particle_id_t b_t_count = box_target_counts_nonchild[i]; + + for (particle_id_t j = b_t_start; j < b_t_start+b_t_count; ++j) + { + particle_id_t user_target_id = user_target_ids[j]; + if (user_order_flags[user_target_id]) + { + APPEND_filt_tgt_list(user_target_id); + } + } + } + """, strict_undefined=True).render( + dtype_to_ctype=dtype_to_ctype, + particle_id_dtype=tree.particle_id_dtype + ), arg_decls=[ + VectorArg(user_order_flags.dtype, "user_order_flags"), + VectorArg(tree.particle_id_dtype, "user_target_ids"), + VectorArg(tree.particle_id_dtype, "box_target_starts"), + VectorArg(tree.particle_id_dtype, "box_target_counts_nonchild"), ]) - # {{{ compute box point source metadata + result, evt = builder(queue, tree.nboxes, + user_order_flags.data, + user_target_ids.data, + tree.box_target_starts.data, tree.box_target_counts_nonchild.data) + + return FilteredTargetListsInUserOrder( + nfiltered_targets=result["filt_tgt_list"].count, + target_starts=result["filt_tgt_list"].starts, + target_lists=result["filt_tgt_list"].lists, + ).with_queue(None) + - from boxtree.tree_build_kernels import POINT_SOURCE_LINKING_BOX_POINT_SOURCES +class FilteredTargetListsInTreeOrder(DeviceDataRecord): + """This class builds subsets of the list of targets in each box (as given by + :attr:`boxtree.Tree.box_target_starts` and + :attr:`boxtree.Tree.box_target_counts_cumul`).This subset is + specified by an array of *flags* in user target order. - knl = POINT_SOURCE_LINKING_BOX_POINT_SOURCES.build( - queue.context, - type_aliases=( - ("particle_id_t", tree.particle_id_dtype), - ("box_id_t", tree.box_id_dtype), - ), - ) + The list consists of target numbers in (sorted) tree target order. + See also :class:`FilteredTargetListsInUserOrder`. - logger.debug("point source linking: box point sources") + .. attribute:: nfiltered_targets - box_point_source_starts = cl.array.empty( - queue, tree.nboxes, tree.particle_id_dtype) - box_point_source_counts_nonchild = cl.array.empty( - queue, tree.nboxes, tree.particle_id_dtype) - box_point_source_counts_cumul = cl.array.empty( - queue, tree.nboxes, tree.particle_id_dtype) + .. attribute:: target_starts - knl( - box_point_source_starts, box_point_source_counts_nonchild, - box_point_source_counts_cumul, + ``particle_id_t [nboxes+1]`` - tree.box_source_starts, tree.box_source_counts_nonchild, - tree.box_source_counts_cumul, + Filtered list of targets in each box. Records start indices in + :attr:`boxtree.Tree.targets` for each box. Use together with + :attr:`target_counts_nonchild`. The lists for each box are + contiguous, so that ``target_starts[ibox+1]`` records the + end of the target list for *ibox*. - tree_order_point_source_starts, - tree_order_point_source_counts, - range=slice(tree.nboxes), queue=queue) + .. attribute:: target_lists - # }}} + ``particle_id_t [nboxes]`` + + Filtered list of targets in each box. Records number of sources from + :attr:`boxtree.Tree.targets` in each box (excluding those belonging to + child boxes). Use together with :attr:`target_starts`. + """ - logger.info("point source linking: complete") - tree_attrs = {} - for attr_name in tree.__class__.fields: - try: - tree_attrs[attr_name] = getattr(tree, attr_name) - except AttributeError: - pass +def filter_target_lists_in_tree_order(queue, tree, flags): + """ + :arg flags: an array of length :attr:`boxtree.Tree.ntargets` of + :class:`numpy.int8` objects, which indicate by being zero that the + corresponding target (in user target order) is not part of the + filtered list, or by being nonzero that it is. + :returns: A :class:`FilteredTargetListsInTreeOrder` + """ - return Tree.__init__(self, - npoint_sources=npoint_sources, - point_source_starts=tree_order_point_source_starts, - point_source_counts=tree_order_point_source_counts, - point_sources=tree_order_point_sources, - user_point_source_ids=user_point_source_ids, - box_point_source_starts=box_point_source_starts, - box_point_source_counts_nonchild=box_point_source_counts_nonchild, - box_point_source_counts_cumul=box_point_source_counts_cumul, + tree_order_flags = cl.array.empty(queue, tree.ntargets, np.int8) + tree_order_flags[tree.sorted_target_ids] = flags + + from pyopencl.tools import VectorArg, dtype_to_ctype + from pyopencl.algorithm import ListOfListsBuilder + from mako.template import Template + builder = ListOfListsBuilder(queue.context, + [("filt_tgt_list", tree.particle_id_dtype)], Template("""//CL// + typedef ${dtype_to_ctype(particle_id_dtype)} particle_id_t; + + void generate(LIST_ARG_DECL USER_ARG_DECL index_type i) + { + particle_id_t b_t_start = box_target_starts[i]; + particle_id_t b_t_count = box_target_counts_nonchild[i]; + + for (particle_id_t j = b_t_start; j < b_t_start+b_t_count; ++j) + { + if (tree_order_flags[j]) + { + APPEND_filt_tgt_list(j); + } + } + } + """, strict_undefined=True).render( + dtype_to_ctype=dtype_to_ctype, + particle_id_dtype=tree.particle_id_dtype + ), arg_decls=[ + VectorArg(tree_order_flags.dtype, "tree_order_flags"), + VectorArg(tree.particle_id_dtype, "box_target_starts"), + VectorArg(tree.particle_id_dtype, "box_target_counts_nonchild"), + ]) - **tree_attrs) + result, evt = builder(queue, tree.nboxes, + tree_order_flags.data, + tree.box_target_starts.data, tree.box_target_counts_nonchild.data) + return FilteredTargetListsInTreeOrder( + nfiltered_targets=result["filt_tgt_list"].count, + target_starts=result["filt_tgt_list"].starts, + target_lists=result["filt_tgt_list"].lists, + ).with_queue(None) # }}} diff --git a/doc/tree.rst b/doc/tree.rst index bf6ca1e92f69103ae57d26a39655af62ff4edfb4..09a12bf4388c1da9c3a98ae1ca28b8078ac3adee 100644 --- a/doc/tree.rst +++ b/doc/tree.rst @@ -27,6 +27,29 @@ Tree with linked point sources .. automethod:: get +.. autofunction:: link_point_sources + +Filtering the lists of targets +------------------------------ + +.. currentmodule:: boxtree.tree + +.. autoclass:: FilteredTargetListsInUserOrder + + .. rubric:: Methods + + .. automethod:: get + +.. autofunction:: filter_target_lists_in_user_order + +.. autoclass:: FilteredTargetListsInTreeOrder + + .. rubric:: Methods + + .. automethod:: get + +.. autofunction:: filter_target_lists_in_tree_order + Build Entrypoint ---------------- diff --git a/test/test_tree.py b/test/test_tree.py index 13d1df86c59b883e81f0f1fbe33026411a36ebef..73e998978e9d23ccef5ca84b17f046bdc006cf70 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -463,8 +463,8 @@ def test_extent_tree(ctx_getter, dims, do_plot=False): 0, (nsources+1)*npoint_sources_per_source, npoint_sources_per_source, dtype=tree.particle_id_dtype) - from boxtree.tree import TreeWithLinkedPointSources - dev_tree = TreeWithLinkedPointSources(queue, dev_tree, + from boxtree.tree import link_point_sources + dev_tree = link_point_sources(queue, dev_tree, point_source_starts, point_sources, debug=True)