diff --git a/boxtree/fmm.py b/boxtree/fmm.py index bc0c773246c06d7f75013fbebbdbbf910479dd3d..dc38cda0a75ffc740a66f0a1f960a9f9c749eb49 100644 --- a/boxtree/fmm.py +++ b/boxtree/fmm.py @@ -113,8 +113,7 @@ def drive_fmm(traversal, expansion_wrangler, src_weights): # contribution *out* of the downward-propagating local expansions) potentials = potentials + wrangler.eval_multipoles( - traversal.level_start_target_box_nrs, - traversal.target_boxes, + traversal.target_boxes_sep_smaller_by_source_level, traversal.from_sep_smaller_by_level, mpole_exps) @@ -273,12 +272,13 @@ class ExpansionWranglerInterface: :meth:`local_expansion_zeros`. """ - def eval_multipoles(self, level_start_target_box_nrs, target_boxes, - starts, lists, mpole_exps): - """For each box in *target_boxes*, evaluate the multipole expansion in - *mpole_exps* in the nearby boxes given in *starts* and *lists*, and - return a new potential array. *starts* and *lists* use :ref:`csr` and - *starts* is indexed like *target_boxes*. + def eval_multipoles(self, + target_boxes_by_source_level, from_sep_smaller_by_level, mpole_exps): + """For a level *i*, each box in *target_boxes_by_source_level[i]*, evaluate + the multipole expansion in *mpole_exps* in the nearby boxes given in + *from_sep_smaller_by_level*, and return a new potential array. + *starts* and *lists* in *from_sep_smaller_by_level[i]* use :ref:`csr` + and *starts* is indexed like *target_boxes_by_source_level[i]*. :returns: a new potential array, see :meth:`output_zeros`. """ diff --git a/boxtree/pyfmmlib_integration.py b/boxtree/pyfmmlib_integration.py index 812eae4aa4d044bb5f198e8d8b8674066014057b..693c13aba11ecbec037b047c9a7acd3e3c7b01a9 100644 --- a/boxtree/pyfmmlib_integration.py +++ b/boxtree/pyfmmlib_integration.py @@ -631,8 +631,9 @@ class FMMLibExpansionWrangler(object): return local_exps - def eval_multipoles(self, level_start_target_box_nrs, target_boxes, - sep_smaller_nonsiblings_by_level, mpole_exps): + def eval_multipoles(self, + target_boxes_by_source_level, sep_smaller_nonsiblings_by_level, + mpole_exps): output = self.output_zeros() mpeval = self.get_expn_eval_routine("mp") @@ -643,7 +644,8 @@ class FMMLibExpansionWrangler(object): rscale = self.level_to_rscale(isrc_level) - for itgt_box, tgt_ibox in enumerate(target_boxes): + for itgt_box, tgt_ibox in \ + enumerate(target_boxes_by_source_level[isrc_level]): tgt_pslice = self._get_target_slice(tgt_ibox) if tgt_pslice.stop - tgt_pslice.start == 0: diff --git a/boxtree/tools.py b/boxtree/tools.py index ece152d1257b0c7827bebc842b72e0c5436d3a11..68ac5f430d8d9681a8e58208da17c88ffe6b4fe5 100644 --- a/boxtree/tools.py +++ b/boxtree/tools.py @@ -270,10 +270,11 @@ class DeviceDataRecord(Record): elif isinstance(val, list): return [transform_val(i) for i in val] elif isinstance(val, BuiltList): - return BuiltList( - count=val.count, - starts=f(val.starts), - lists=f(val.lists)) + transformed_list = {} + for field in val.__dict__: + if field != 'count' and not field.startswith('_'): + transformed_list[field] = f(getattr(val, field)) + return BuiltList(count=val.count, **transformed_list) else: return f(val) diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 3448ee56a0f359160cc0bfcec17957029b56cbba..6caadc36385f81642be3bacb28b9030d1c9b4c45 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -1305,17 +1305,30 @@ class FMMTraversalInfo(DeviceDataRecord): Indexed like :attr:`target_or_target_parent_boxes`. See :ref:`csr`. + .. attribute:: target_boxes_sep_smaller_by_source_level + + A list of arrays, one per level, indicating which target boxes are used with + the interaction list entries of :attr:`from_sep_smaller_by_level`. + ``target_boxes_sep_smaller_by_source_level[i]`` has length + ``from_sep_smaller_by_level[i].num_nonempty_lists`. + + .. attribute:: from_sep_smaller_by_level A list of :attr:`boxtree.Tree.nlevels` (corresponding to the levels on which each listed source box resides) objects, each of which has - attributes *count*, *starts* and *lists*, which form a CSR list of List - 3 source boxes. + attributes *count*, *starts*, *lists*, *num_nonempty_lists*, and + *nonempty_indices*, which form a CSR list of List 3 source boxes. - *starts* has shape/type ``box_id_t [ntarget_boxes+1]``. *lists* is of type - ``box_id_t``. (Note: This list contains global box numbers, not + *starts* has shape/type ``box_id_t [num_nonempty_lists+1]``. *lists* is of + type ``box_id_t``. (Note: This list contains global box numbers, not indices into :attr:`source_boxes`.) + Note *starts* are indexed by `target_boxes_sep_smaller_by_source_level`. For + example, for level *i*, *lists[starts[j]:starts[j+1]]* represents "List 3" + source boxes of *target_boxes_sep_smaller_by_source_level[i][j]* on level + *i*. + .. attribute:: from_sep_close_smaller_starts ``box_id_t [ntarget_boxes+1]`` (or *None*) @@ -1692,13 +1705,13 @@ class FMMTraversalBuilder: VectorArg(box_flags_enum.dtype, "box_flags"), ] - for list_name, template, extra_args, extra_lists in [ + for list_name, template, extra_args, extra_lists, eliminate_empty_list in [ ("same_level_non_well_sep_boxes", - SAME_LEVEL_NON_WELL_SEP_BOXES_TEMPLATE, [], []), + SAME_LEVEL_NON_WELL_SEP_BOXES_TEMPLATE, [], [], []), ("neighbor_source_boxes", NEIGBHOR_SOURCE_BOXES_TEMPLATE, [ VectorArg(box_id_dtype, "target_boxes"), - ], []), + ], [], []), ("from_sep_siblings", FROM_SEP_SIBLINGS_TEMPLATE, [ VectorArg(box_id_dtype, "target_or_target_parent_boxes"), @@ -1707,7 +1720,7 @@ class FMMTraversalBuilder: "same_level_non_well_sep_boxes_starts"), VectorArg(box_id_dtype, "same_level_non_well_sep_boxes_lists"), - ], []), + ], [], []), ("from_sep_smaller", FROM_SEP_SMALLER_TEMPLATE, [ ScalarArg(coord_dtype, "stick_out_factor"), @@ -1725,7 +1738,7 @@ class FMMTraversalBuilder: ], ["from_sep_close_smaller"] if sources_have_extent or targets_have_extent - else []), + else [], ["from_sep_smaller"]), ("from_sep_bigger", FROM_SEP_BIGGER_TEMPLATE, [ ScalarArg(coord_dtype, "stick_out_factor"), @@ -1738,7 +1751,7 @@ class FMMTraversalBuilder: ], ["from_sep_close_bigger"] if sources_have_extent or targets_have_extent - else []), + else [], []), ]: src = Template( TRAVERSAL_PREAMBLE_TEMPLATE @@ -1753,7 +1766,8 @@ class FMMTraversalBuilder: str(src), arg_decls=base_args + extra_args, debug=debug, name_prefix=list_name, - complex_kernel=True) + complex_kernel=True, + eliminate_empty_output_lists=eliminate_empty_list) # }}} @@ -2032,6 +2046,7 @@ class FMMTraversalBuilder: from_sep_smaller_wait_for = [] from_sep_smaller_by_level = [] + target_boxes_sep_smaller_by_source_level = [] for ilevel in range(tree.nlevels): fin_debug("finding separated smaller ('list 3 level %d')" % ilevel) @@ -2041,7 +2056,11 @@ class FMMTraversalBuilder: omit_lists=("from_sep_close_smaller",) if with_extent else (), wait_for=wait_for) + target_boxes_sep_smaller = target_boxes[ + result["from_sep_smaller"].nonempty_indices] + from_sep_smaller_by_level.append(result["from_sep_smaller"]) + target_boxes_sep_smaller_by_source_level.append(target_boxes_sep_smaller) from_sep_smaller_wait_for.append(evt) if with_extent: @@ -2137,6 +2156,8 @@ class FMMTraversalBuilder: from_sep_siblings_lists=from_sep_siblings.lists, from_sep_smaller_by_level=from_sep_smaller_by_level, + target_boxes_sep_smaller_by_source_level=( + target_boxes_sep_smaller_by_source_level), from_sep_close_smaller_starts=from_sep_close_smaller_starts, from_sep_close_smaller_lists=from_sep_close_smaller_lists, diff --git a/test/test_fmm.py b/test/test_fmm.py index e29d8bd0198c27b56b289f3e8e897dbc0dbc7b7d..fd10c7538f9d12ec2add04e1e2fabb78d076cf3f 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -141,12 +141,14 @@ class ConstantOneExpansionWrangler(object): return local_exps - def eval_multipoles(self, level_start_target_box_nrs, target_boxes, - from_sep_smaller_nonsiblings_by_level, mpole_exps): + def eval_multipoles(self, + target_boxes_by_source_level, from_sep_smaller_nonsiblings_by_level, + mpole_exps): pot = self.potential_zeros() - for ssn in from_sep_smaller_nonsiblings_by_level: - for itgt_box, tgt_ibox in enumerate(target_boxes): + for level, ssn in enumerate(from_sep_smaller_nonsiblings_by_level): + for itgt_box, tgt_ibox in \ + enumerate(target_boxes_by_source_level[level]): tgt_pslice = self._get_target_slice(tgt_ibox) contrib = 0 diff --git a/test/test_traversal.py b/test/test_traversal.py index 6c4096e3b4dfde84617ce29324dff04a5c73d6a1..8f7935e713194e8ee1b3d4a70b172c76b76fc39e 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -140,9 +140,10 @@ def test_tree_connectivity(ctx_getter, dims, sources_are_targets): assert (trav.target_or_target_parent_boxes == np.arange(tree.nboxes)).all() # {{{ list 4 <= list 3 - for itarget_box, ibox in enumerate(trav.target_boxes): - for ssn in trav.from_sep_smaller_by_level: + for level, ssn in enumerate(trav.from_sep_smaller_by_level): + for itarget_box, ibox in \ + enumerate(trav.target_boxes_sep_smaller_by_source_level[level]): start, end = ssn.starts[itarget_box:itarget_box+2] for jbox in ssn.lists[start:end]: @@ -155,10 +156,19 @@ def test_tree_connectivity(ctx_getter, dims, sources_are_targets): # {{{ list 4 <= list 3 - box_to_target_box_index = np.empty(tree.nboxes, tree.box_id_dtype) - box_to_target_box_index.fill(-1) - box_to_target_box_index[trav.target_boxes] = np.arange( - len(trav.target_boxes), dtype=tree.box_id_dtype) + box_to_target_boxes_sep_smaller_by_source_level = [] + for level in range(trav.tree.nlevels): + box_to_target_boxes_sep_smaller = np.empty( + tree.nboxes, tree.box_id_dtype) + box_to_target_boxes_sep_smaller.fill(-1) + box_to_target_boxes_sep_smaller[ + trav.target_boxes_sep_smaller_by_source_level[level] + ] = np.arange( + len(trav.target_boxes_sep_smaller_by_source_level[level]), + dtype=tree.box_id_dtype + ) + box_to_target_boxes_sep_smaller_by_source_level.append( + box_to_target_boxes_sep_smaller) assert (trav.source_boxes == trav.target_boxes).all() assert (trav.target_or_target_parent_boxes == np.arange( @@ -173,13 +183,14 @@ def test_tree_connectivity(ctx_getter, dims, sources_are_targets): # are the same thing (i.e. leaves--see assertion above), so we # may treat them as targets anyhow. - jtgt_box = box_to_target_box_index[jbox] - assert jtgt_box != -1 - good = False - for ssn in trav.from_sep_smaller_by_level: - rstart, rend = ssn.starts[jtgt_box:jtgt_box+2] + for level, ssn in enumerate(trav.from_sep_smaller_by_level): + jtgt_box = \ + box_to_target_boxes_sep_smaller_by_source_level[level][jbox] + if jtgt_box == -1: + continue + rstart, rend = ssn.starts[jtgt_box:jtgt_box + 2] good = good or ibox in ssn.lists[rstart:rend] if not good: @@ -207,8 +218,11 @@ def test_tree_connectivity(ctx_getter, dims, sources_are_targets): # {{{ from_sep_smaller satisfies relative level assumption - for itarget_box, ibox in enumerate(trav.target_boxes): - for ssn in trav.from_sep_smaller_by_level: + # for itarget_box, ibox in enumerate(trav.target_boxes): + # for ssn in trav.from_sep_smaller_by_level: + for level, ssn in enumerate(trav.from_sep_smaller_by_level): + for itarget_box, ibox in enumerate( + trav.target_boxes_sep_smaller_by_source_level[level]): start, end = ssn.starts[itarget_box:itarget_box+2] for jbox in ssn.lists[start:end]: