From af047bcad66879724e959de5e33dc59db9e473ae Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 14 Dec 2021 00:53:37 -0600 Subject: [PATCH] Track TreeIndependentData refactor from boxtree, remove persistent queue from wrangler --- sumpy/fmm.py | 262 ++++++++++++++++++++++++++--------------------- test/test_fmm.py | 50 +++++---- 2 files changed, 169 insertions(+), 143 deletions(-) diff --git a/sumpy/fmm.py b/sumpy/fmm.py index 26acf713..8636f17c 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -22,7 +22,7 @@ THE SOFTWARE. __doc__ = """Integrates :mod:`boxtree` with :mod:`sumpy`. -.. autoclass:: SumpyExpansionWranglerCodeContainer +.. autoclass:: SumpyTreeIndependentDataForWrangler .. autoclass:: SumpyExpansionWrangler """ @@ -31,6 +31,7 @@ import pyopencl as cl import pyopencl.array # noqa from pytools import memoize_method +from boxtree.fmm import TreeIndependentDataForWrangler, ExpansionWranglerInterface from sumpy import ( P2EFromSingleBox, P2EFromCSR, @@ -47,9 +48,9 @@ def level_to_rscale(tree, level): return tree.root_extent * (2**-level) -# {{{ expansion wrangler code container +# {{{ tree-independent data for wrangler -class SumpyExpansionWranglerCodeContainer: +class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler): """Objects of this type serve as a place to keep the code needed for :class:`SumpyExpansionWrangler`. Since :class:`SumpyExpansionWrangler` necessarily must have a :class:`pyopencl.CommandQueue`, but this queue @@ -86,6 +87,8 @@ class SumpyExpansionWranglerCodeContainer: self.strength_usage = strength_usage self.use_preprocessing_for_m2l = use_preprocessing_for_m2l + super().__init__() + self.cl_context = cl_context @memoize_method @@ -170,18 +173,6 @@ class SumpyExpansionWranglerCodeContainer: exclude_self=self.exclude_self, strength_usage=self.strength_usage) - def get_wrangler(self, queue, tree, dtype, fmm_level_to_order, - source_extra_kwargs=None, - kernel_extra_kwargs=None, - self_extra_kwargs=None, - translation_classes_data=None): - - if source_extra_kwargs is None: - source_extra_kwargs = {} - - return SumpyExpansionWrangler(self, queue, tree, dtype, fmm_level_to_order, - source_extra_kwargs, kernel_extra_kwargs, self_extra_kwargs, - translation_classes_data=translation_classes_data) # }}} @@ -202,7 +193,7 @@ class SumpyTimingFuture: @memoize_method def result(self): - from boxtree.fmm import TimingResult + from boxtree.timing import TimingResult if not self.queue.properties & cl.command_queue_properties.PROFILING_ENABLE: from warnings import warn @@ -283,7 +274,7 @@ class SumpyTranslationClassesDataNotSuppliedWarning(UserWarning): # {{{ expansion wrangler -class SumpyExpansionWrangler: +class SumpyExpansionWrangler(ExpansionWranglerInterface): """Implements the :class:`boxtree.fmm.ExpansionWranglerInterface` by using :mod:`sumpy` expansions/translations. @@ -308,20 +299,18 @@ class SumpyExpansionWrangler: Type for the preprocessed multipole expansion if used for M2L. """ - def __init__(self, code_container, queue, tree, dtype, fmm_level_to_order, - source_extra_kwargs, + def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order, + source_extra_kwargs=None, kernel_extra_kwargs=None, self_extra_kwargs=None, translation_classes_data=None, preprocessed_mpole_dtype=None): - self.code = code_container - self.queue = queue - self.tree = tree + super().__init__(tree_indep, traversal) self.issued_timing_data_warning = False self.dtype = dtype - if not self.code.use_preprocessing_for_m2l: + if not self.tree_indep.use_preprocessing_for_m2l: # If not FFT, we don't need complex dtypes self.preprocessed_mpole_dtype = dtype elif preprocessed_mpole_dtype is not None: @@ -330,20 +319,21 @@ class SumpyExpansionWrangler: # FIXME: It is weird that the wrangler has to compute this. self.preprocessed_mpole_dtype = to_complex_dtype(dtype) + if source_extra_kwargs is None: + source_extra_kwargs = {} if kernel_extra_kwargs is None: kernel_extra_kwargs = {} - if self_extra_kwargs is None: self_extra_kwargs = {} if not callable(fmm_level_to_order): raise TypeError("fmm_level_to_order not passed") - base_kernel = code_container.get_base_kernel() + base_kernel = tree_indep.get_base_kernel() kernel_arg_set = frozenset(kernel_extra_kwargs.items()) self.level_orders = [ - fmm_level_to_order(base_kernel, kernel_arg_set, tree, lev) - for lev in range(tree.nlevels)] + fmm_level_to_order(base_kernel, kernel_arg_set, traversal.tree, lev) + for lev in range(traversal.tree.nlevels)] self.source_extra_kwargs = source_extra_kwargs self.kernel_extra_kwargs = kernel_extra_kwargs @@ -355,7 +345,7 @@ class SumpyExpansionWrangler: if base_kernel.is_translation_invariant: if translation_classes_data is None: from warnings import warn - if self.code.use_preprocessing_for_m2l: + if self.tree_indep.use_preprocessing_for_m2l: raise NotImplementedError( "FFT based List 2 (multipole-to-local) translations " "without translation_classes_data argument is not " @@ -375,9 +365,10 @@ class SumpyExpansionWrangler: self.supports_optimized_m2l = False self.translation_classes_data = translation_classes_data - self.use_preprocessing_for_m2l = self.code.use_preprocessing_for_m2l + self.use_preprocessing_for_m2l = self.tree_indep.use_preprocessing_for_m2l # {{{ data vector utilities + def _expansions_level_starts(self, order_to_size): return build_csr_level_starts(self.level_orders, order_to_size, self.tree.level_start_box_nrs) @@ -385,43 +376,60 @@ class SumpyExpansionWrangler: @memoize_method def multipole_expansions_level_starts(self): return self._expansions_level_starts( - lambda order: len(self.code.multipole_expansion(order))) + lambda order: len(self.tree_indep.multipole_expansion(order))) @memoize_method def local_expansions_level_starts(self): return self._expansions_level_starts( - lambda order: len(self.code.local_expansion(order))) + lambda order: len(self.tree_indep.local_expansion(order))) @memoize_method def m2l_translation_class_level_start_box_nrs(self): - data = self.translation_classes_data - return data.m2l_translation_classes_level_starts().get(self.queue) + with cl.CommandQueue(self.tree_indep.cl_context) as queue: + data = self.translation_classes_data + return data.m2l_translation_classes_level_starts().get(queue) @memoize_method def m2l_translation_classes_dependent_data_level_starts(self): def order_to_size(order): - mpole_expn = self.code.multipole_expansion(order) - local_expn = self.code.local_expansion(order) + mpole_expn = self.tree_indep.multipole_expansion(order) + local_expn = self.tree_indep.local_expansion(order) return local_expn.m2l_translation_classes_dependent_ndata(mpole_expn) return build_csr_level_starts(self.level_orders, order_to_size, level_starts=self.m2l_translation_class_level_start_box_nrs()) - def multipole_expansion_zeros(self): + def multipole_expansion_zeros(self, template_ary): + """Return an expansions array (which must support addition) + capable of holding one multipole or local expansion for every + box in the tree. + :arg template_ary: an array (not necessarily of the same shape or dtype as + the one to be created) whose run-time environment + (e.g. :class:`pyopencl.CommandQueue`) the returned array should + reuse. + """ return cl.array.zeros( - self.queue, + template_ary.queue, self.multipole_expansions_level_starts()[-1], dtype=self.dtype) - def local_expansion_zeros(self): + def local_expansion_zeros(self, template_ary): + """Return an expansions array (which must support addition) + capable of holding one multipole or local expansion for every + box in the tree. + :arg template_ary: an array (not necessarily of the same shape or dtype as + the one to be created) whose run-time environment + (e.g. :class:`pyopencl.CommandQueue`) the returned array should + reuse. + """ return cl.array.zeros( - self.queue, + template_ary.queue, self.local_expansions_level_starts()[-1], dtype=self.dtype) - def m2l_translation_classes_dependent_data_zeros(self): + def m2l_translation_classes_dependent_data_zeros(self, queue): return cl.array.zeros( - self.queue, + queue, self.m2l_translation_classes_dependent_data_level_starts()[-1], dtype=self.preprocessed_mpole_dtype) @@ -455,17 +463,17 @@ class SumpyExpansionWrangler: @memoize_method def m2l_preproc_mpole_expansions_level_starts(self): def order_to_size(order): - mpole_expn = self.code.multipole_expansion(order) - local_expn = self.code.local_expansion(order) + mpole_expn = self.tree_indep.multipole_expansion(order) + local_expn = self.tree_indep.local_expansion(order) res = local_expn.m2l_translation_classes_dependent_ndata(mpole_expn) return res return build_csr_level_starts(self.level_orders, order_to_size, level_starts=self.tree.level_start_box_nrs) - def m2l_preproc_mpole_expansion_zeros(self): + def m2l_preproc_mpole_expansion_zeros(self, template_ary): return cl.array.zeros( - self.queue, + template_ary.queue, self.m2l_preproc_mpole_expansions_level_starts()[-1], dtype=self.preprocessed_mpole_dtype) @@ -477,17 +485,27 @@ class SumpyExpansionWrangler: return (box_start, mpole_exps[expn_start:expn_stop].reshape(box_stop-box_start, -1)) - def output_zeros(self): + def output_zeros(self, template_ary): + """Return a potentials array (which must support addition) capable of + holding a potential value for each target in the tree. Note that + :func:`drive_fmm` makes no assumptions about *potential* other than + that it supports addition--it may consist of potentials, gradients of + the potential, or arbitrary other per-target output data. + :arg template_ary: an array (not necessarily of the same shape or dtype as + the one to be created) whose run-time environment + (e.g. :class:`pyopencl.CommandQueue`) the returned array should + reuse. + """ from pytools.obj_array import make_obj_array return make_obj_array([ cl.array.zeros( - self.queue, + template_ary.queue, self.tree.ntargets, dtype=self.dtype) - for k in self.code.target_kernels]) + for k in self.tree_indep.target_kernels]) def reorder_sources(self, source_array): - return source_array.with_queue(self.queue)[self.tree.user_source_ids] + return source_array.with_queue(source_array.queue)[self.tree.user_source_ids] def reorder_potentials(self, potentials): from pytools.obj_array import obj_array_vectorize @@ -497,7 +515,7 @@ class SumpyExpansionWrangler: and potentials.dtype.char == "O") def reorder(x): - return x.with_queue(self.queue)[self.tree.sorted_target_ids] + return x[self.tree.sorted_target_ids] return obj_array_vectorize(reorder, potentials) @@ -526,15 +544,16 @@ class SumpyExpansionWrangler: def form_multipoles(self, level_start_source_box_nrs, source_boxes, src_weight_vecs): - mpoles = self.multipole_expansion_zeros() + mpoles = self.multipole_expansion_zeros(src_weight_vecs[0]) kwargs = self.extra_kwargs.copy() kwargs.update(self.box_source_list_kwargs()) events = [] + queue = src_weight_vecs[0].queue for lev in range(self.tree.nlevels): - p2m = self.code.p2m(self.level_orders[lev]) + p2m = self.tree_indep.p2m(self.level_orders[lev]) start, stop = level_start_source_box_nrs[lev:lev+2] if start == stop: continue @@ -543,7 +562,7 @@ class SumpyExpansionWrangler: mpoles, lev) evt, (mpoles_res,) = p2m( - self.queue, + queue, source_boxes=source_boxes[start:stop], centers=self.tree.box_centers, strengths=src_weight_vecs, @@ -557,7 +576,7 @@ class SumpyExpansionWrangler: assert mpoles_res is mpoles_view - return (mpoles, SumpyTimingFuture(self.queue, events)) + return (mpoles, SumpyTimingFuture(queue, events)) def coarsen_multipoles(self, level_start_source_parent_box_nrs, @@ -566,6 +585,7 @@ class SumpyExpansionWrangler: tree = self.tree events = [] + queue = mpoles.queue # nlevels-1 is the last valid level index # nlevels-2 is the last valid level that could have children @@ -583,7 +603,7 @@ class SumpyExpansionWrangler: print("source", source_level, "empty") continue - m2m = self.code.m2m( + m2m = self.tree_indep.m2m( self.level_orders[source_level], self.level_orders[target_level]) @@ -593,7 +613,7 @@ class SumpyExpansionWrangler: self.multipole_expansions_view(mpoles, target_level) evt, (mpoles_res,) = m2m( - self.queue, + queue, src_expansions=source_mpoles_view, src_base_ibox=source_level_start_ibox, tgt_expansions=target_mpoles_view, @@ -614,11 +634,11 @@ class SumpyExpansionWrangler: if events: mpoles.add_event(events[-1]) - return (mpoles, SumpyTimingFuture(self.queue, events)) + return (mpoles, SumpyTimingFuture(queue, events)) def eval_direct(self, target_boxes, source_box_starts, source_box_lists, src_weight_vecs): - pot = self.output_zeros() + pot = self.output_zeros(src_weight_vecs[0]) kwargs = self.extra_kwargs.copy() kwargs.update(self.self_extra_kwargs) @@ -626,8 +646,9 @@ class SumpyExpansionWrangler: kwargs.update(self.box_target_list_kwargs()) events = [] + queue = src_weight_vecs[0].queue - evt, pot_res = self.code.p2p()(self.queue, + evt, pot_res = self.tree_indep.p2p()(queue, target_boxes=target_boxes, source_box_starts=source_box_starts, source_box_lists=source_box_lists, @@ -641,49 +662,52 @@ class SumpyExpansionWrangler: assert pot_i is pot_res_i pot_i.add_event(evt) - return (pot, SumpyTimingFuture(self.queue, events)) + return (pot, SumpyTimingFuture(queue, events)) @memoize_method def multipole_to_local_precompute(self, src_rscale): - m2l_translation_classes_dependent_data = \ - self.m2l_translation_classes_dependent_data_zeros() - for lev in range(self.tree.nlevels): - order = self.level_orders[lev] - precompute_kernel = \ - self.code.m2l_translation_class_dependent_data_kernel(order, order) + with cl.CommandQueue(self.tree_indep.cl_context) as queue: + m2l_translation_classes_dependent_data = \ + self.m2l_translation_classes_dependent_data_zeros(queue) + for lev in range(self.tree.nlevels): + order = self.level_orders[lev] + precompute_kernel = \ + self.tree_indep.m2l_translation_class_dependent_data_kernel( + order, order) - translation_classes_level_start, \ - m2l_translation_classes_dependent_data_view = \ - self.m2l_translation_classes_dependent_data_view( - m2l_translation_classes_dependent_data, lev) + translation_classes_level_start, \ + m2l_translation_classes_dependent_data_view = \ + self.m2l_translation_classes_dependent_data_view( + m2l_translation_classes_dependent_data, lev) - ntranslation_classes = \ - m2l_translation_classes_dependent_data_view.shape[0] + ntranslation_classes = \ + m2l_translation_classes_dependent_data_view.shape[0] - if ntranslation_classes == 0: - continue + if ntranslation_classes == 0: + continue - m2l_translation_vectors = ( - self.translation_classes_data.m2l_translation_vectors()) + m2l_translation_vectors = ( + self.translation_classes_data.m2l_translation_vectors()) + + evt, _ = precompute_kernel( + queue, + src_rscale=src_rscale, + translation_classes_level_start=translation_classes_level_start, + ntranslation_classes=ntranslation_classes, + m2l_translation_classes_dependent_data=( + m2l_translation_classes_dependent_data_view), + m2l_translation_vectors=m2l_translation_vectors, + ntranslation_vectors=m2l_translation_vectors.shape[1], + **self.kernel_extra_kwargs + ) + m2l_translation_classes_dependent_data.add_event(evt) - evt, _ = precompute_kernel( - self.queue, - src_rscale=src_rscale, - translation_classes_level_start=translation_classes_level_start, - ntranslation_classes=ntranslation_classes, - m2l_translation_classes_dependent_data=( - m2l_translation_classes_dependent_data_view), - m2l_translation_vectors=m2l_translation_vectors, - ntranslation_vectors=m2l_translation_vectors.shape[1], - **self.kernel_extra_kwargs - ) - m2l_translation_classes_dependent_data.add_event(evt) + m2l_translation_classes_dependent_data.finish() - m2l_translation_classes_dependent_data.finish() + m2l_translation_classes_dependent_data = \ + m2l_translation_classes_dependent_data.with_queue(None) - return (m2l_translation_classes_dependent_data, - SumpyTimingFuture(self.queue, - m2l_translation_classes_dependent_data.events[:])) + return m2l_translation_classes_dependent_data def _add_m2l_precompute_kwargs(self, kwargs_for_m2l, lev): @@ -694,7 +718,7 @@ class SumpyExpansionWrangler: if not self.supports_optimized_m2l: return src_rscale = kwargs_for_m2l["src_rscale"] - m2l_translation_classes_dependent_data, _ = \ + m2l_translation_classes_dependent_data = \ self.multipole_to_local_precompute(src_rscale) translation_classes_level_start, \ m2l_translation_classes_dependent_data_view = \ @@ -713,13 +737,15 @@ class SumpyExpansionWrangler: mpole_exps): precompute_evts = [] + queue = mpole_exps.queue if self.use_preprocessing_for_m2l: - preprocessed_mpole_exps = self.m2l_preproc_mpole_expansion_zeros() + preprocessed_mpole_exps = \ + self.m2l_preproc_mpole_expansion_zeros(mpole_exps) for lev in range(self.tree.nlevels): order = self.level_orders[lev] preprocess_mpole_kernel = \ - self.code.m2l_preprocess_mpole_kernel(order, order) + self.tree_indep.m2l_preprocess_mpole_kernel(order, order) source_level_start_ibox, source_mpoles_view = \ self.multipole_expansions_view(mpole_exps, lev) @@ -734,7 +760,7 @@ class SumpyExpansionWrangler: continue evt, _ = preprocess_mpole_kernel( - self.queue, + queue, src_expansions=source_mpoles_view, preprocessed_src_expansions=preprocessed_source_mpoles_view, src_rscale=level_to_rscale(self.tree, lev), @@ -747,7 +773,7 @@ class SumpyExpansionWrangler: mpole_exps_view_func = self.multipole_expansions_view events = [] - local_exps = self.local_expansion_zeros() + local_exps = self.local_expansion_zeros(mpole_exps) for lev in range(self.tree.nlevels): start, stop = level_start_target_box_nrs[lev:lev+2] @@ -755,7 +781,7 @@ class SumpyExpansionWrangler: continue order = self.level_orders[lev] - m2l = self.code.m2l(order, order, self.supports_optimized_m2l) + m2l = self.tree_indep.m2l(order, order, self.supports_optimized_m2l) source_level_start_ibox, source_mpoles_view = \ mpole_exps_view_func(mpole_exps, lev) @@ -783,20 +809,21 @@ class SumpyExpansionWrangler: kwargs["m2l_translation_classes_dependent_data"].size == 0: # There is nothing to do for this level continue - evt, _ = m2l(self.queue, **kwargs, wait_for=precompute_evts) + evt, _ = m2l(queue, **kwargs, wait_for=precompute_evts) events.append(evt) - return (local_exps, SumpyTimingFuture(self.queue, events)) + return (local_exps, SumpyTimingFuture(queue, events)) def eval_multipoles(self, target_boxes_by_source_level, source_boxes_by_level, mpole_exps): - pot = self.output_zeros() + pot = self.output_zeros(mpole_exps) kwargs = self.kernel_extra_kwargs.copy() kwargs.update(self.box_target_list_kwargs()) events = [] + queue = mpole_exps.queue wait_for = mpole_exps.events @@ -804,13 +831,13 @@ class SumpyExpansionWrangler: if len(target_boxes_by_source_level[isrc_level]) == 0: continue - m2p = self.code.m2p(self.level_orders[isrc_level]) + m2p = self.tree_indep.m2p(self.level_orders[isrc_level]) source_level_start_ibox, source_mpoles_view = \ self.multipole_expansions_view(mpole_exps, isrc_level) evt, pot_res = m2p( - self.queue, + queue, src_expansions=source_mpoles_view, src_base_ibox=source_level_start_ibox, @@ -837,17 +864,18 @@ class SumpyExpansionWrangler: for pot_i in pot: pot_i.add_event(events[-1]) - return (pot, SumpyTimingFuture(self.queue, events)) + return (pot, SumpyTimingFuture(queue, events)) def form_locals(self, level_start_target_or_target_parent_box_nrs, target_or_target_parent_boxes, starts, lists, src_weight_vecs): - local_exps = self.local_expansion_zeros() + local_exps = self.local_expansion_zeros(src_weight_vecs[0]) kwargs = self.extra_kwargs.copy() kwargs.update(self.box_source_list_kwargs()) events = [] + queue = src_weight_vecs[0].queue for lev in range(self.tree.nlevels): start, stop = \ @@ -855,13 +883,13 @@ class SumpyExpansionWrangler: if start == stop: continue - p2l = self.code.p2l(self.level_orders[lev]) + p2l = self.tree_indep.p2l(self.level_orders[lev]) target_level_start_ibox, target_local_exps_view = \ self.local_expansions_view(local_exps, lev) evt, (result,) = p2l( - self.queue, + queue, target_boxes=target_or_target_parent_boxes[start:stop], source_box_starts=starts[start:stop+1], source_box_lists=lists, @@ -878,7 +906,7 @@ class SumpyExpansionWrangler: assert result is target_local_exps_view - return (local_exps, SumpyTimingFuture(self.queue, events)) + return (local_exps, SumpyTimingFuture(queue, events)) def refine_locals(self, level_start_target_or_target_parent_box_nrs, @@ -886,6 +914,7 @@ class SumpyExpansionWrangler: local_exps): events = [] + queue = local_exps.queue for target_lev in range(1, self.tree.nlevels): start, stop = level_start_target_or_target_parent_box_nrs[ @@ -894,7 +923,7 @@ class SumpyExpansionWrangler: continue source_lev = target_lev - 1 - l2l = self.code.l2l( + l2l = self.tree_indep.l2l( self.level_orders[source_lev], self.level_orders[target_lev]) @@ -903,7 +932,7 @@ class SumpyExpansionWrangler: target_level_start_ibox, target_local_exps_view = \ self.local_expansions_view(local_exps, target_lev) - evt, (local_exps_res,) = l2l(self.queue, + evt, (local_exps_res,) = l2l(queue, src_expansions=source_local_exps_view, src_base_ibox=source_level_start_ibox, tgt_expansions=target_local_exps_view, @@ -923,28 +952,29 @@ class SumpyExpansionWrangler: local_exps.add_event(evt) - return (local_exps, SumpyTimingFuture(self.queue, [evt])) + return (local_exps, SumpyTimingFuture(queue, [evt])) def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps): - pot = self.output_zeros() + pot = self.output_zeros(local_exps) kwargs = self.kernel_extra_kwargs.copy() kwargs.update(self.box_target_list_kwargs()) events = [] + queue = local_exps.queue for lev in range(self.tree.nlevels): start, stop = level_start_target_box_nrs[lev:lev+2] if start == stop: continue - l2p = self.code.l2p(self.level_orders[lev]) + l2p = self.tree_indep.l2p(self.level_orders[lev]) source_level_start_ibox, source_local_exps_view = \ self.local_expansions_view(local_exps, lev) evt, pot_res = l2p( - self.queue, + queue, src_expansions=source_local_exps_view, src_base_ibox=source_level_start_ibox, @@ -961,9 +991,9 @@ class SumpyExpansionWrangler: for pot_i, pot_res_i in zip(pot, pot_res): assert pot_i is pot_res_i - return (pot, SumpyTimingFuture(self.queue, events)) + return (pot, SumpyTimingFuture(queue, events)) - def finalize_potentials(self, potentials): + def finalize_potentials(self, potentials, template_ary): return potentials # }}} diff --git a/test/test_fmm.py b/test/test_fmm.py index 69be8a2f..3e2f84fd 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -36,7 +36,11 @@ from sumpy.expansion.local import ( VolumeTaylorLocalExpansion, H2DLocalExpansion, Y2DLocalExpansion, LinearPDEConformingVolumeTaylorLocalExpansion) -from sumpy.fmm import SumpyTranslationClassesData +from sumpy.fmm import ( + SumpyTreeIndependentDataForWrangler, + SumpyExpansionWrangler, + SumpyTranslationClassesData, + SumpyTranslationClassesDataNotSuppliedWarning) import pytest import warnings @@ -177,15 +181,12 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, for order in order_values: target_kernels = [knl] - from sumpy.fmm import (SumpyExpansionWranglerCodeContainer, - SumpyTranslationClassesDataNotSuppliedWarning) - if optimized_m2l: translation_classes_data = SumpyTranslationClassesData(queue, trav) else: translation_classes_data = None - wcc = SumpyExpansionWranglerCodeContainer( + tree_indep = SumpyTreeIndependentDataForWrangler( ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), @@ -195,14 +196,14 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, if not optimized_m2l: warnings.simplefilter("ignore", SumpyTranslationClassesDataNotSuppliedWarning) - wrangler = wcc.get_wrangler(queue, tree, dtype, + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, kernel_extra_kwargs=extra_kwargs, translation_classes_data=translation_classes_data) from boxtree.fmm import drive_fmm - pot, = drive_fmm(trav, wrangler, (weights,)) + pot, = drive_fmm(wrangler, (weights,)) from sumpy import P2P p2p = P2P(ctx, target_kernels, exclude_self=False) @@ -292,21 +293,20 @@ def test_unified_single_and_double(ctx_factory): source_extra_kwargs = {} if deriv_knl in source_kernels: source_extra_kwargs["dir_vec"] = dir_vec - from sumpy.fmm import SumpyExpansionWranglerCodeContainer - wcc = SumpyExpansionWranglerCodeContainer( + tree_indep = SumpyTreeIndependentDataForWrangler( ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), target_kernels=target_kernels, source_kernels=source_kernels, strength_usage=strength_usage) - wrangler = wcc.get_wrangler(queue, tree, dtype, + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, source_extra_kwargs=source_extra_kwargs, translation_classes_data=SumpyTranslationClassesData(queue, trav)) from boxtree.fmm import drive_fmm - pot = drive_fmm(trav, wrangler, weights) + pot = drive_fmm(wrangler, weights) results.append(np.array([pot[0].get(), pot[1].get()])) ref_pot = results[0] + results[1] @@ -355,20 +355,19 @@ def test_sumpy_fmm_timing_data_collection(ctx_factory): from functools import partial - from sumpy.fmm import SumpyExpansionWranglerCodeContainer - wcc = SumpyExpansionWranglerCodeContainer( + tree_indep = SumpyTreeIndependentDataForWrangler( ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), target_kernels) - wrangler = wcc.get_wrangler(queue, tree, dtype, + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, translation_classes_data=SumpyTranslationClassesData(queue, trav)) from boxtree.fmm import drive_fmm timing_data = {} - pot, = drive_fmm(trav, wrangler, (weights,), timing_data=timing_data) + pot, = drive_fmm(wrangler, (weights,), timing_data=timing_data) print(timing_data) assert timing_data @@ -413,22 +412,21 @@ def test_sumpy_fmm_exclude_self(ctx_factory): from functools import partial - from sumpy.fmm import SumpyExpansionWranglerCodeContainer - wcc = SumpyExpansionWranglerCodeContainer( + tree_indep = SumpyTreeIndependentDataForWrangler( ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), target_kernels, exclude_self=True) - wrangler = wcc.get_wrangler(queue, tree, dtype, + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, self_extra_kwargs=self_extra_kwargs, translation_classes_data=SumpyTranslationClassesData(queue, trav)) from boxtree.fmm import drive_fmm - pot, = drive_fmm(trav, wrangler, (weights,)) + pot, = drive_fmm(wrangler, (weights,)) from sumpy import P2P p2p = P2P(ctx, target_kernels, exclude_self=True) @@ -482,14 +480,13 @@ def test_sumpy_axis_source_derivative(ctx_factory): from functools import partial - from sumpy.fmm import SumpyExpansionWranglerCodeContainer from sumpy.kernel import AxisTargetDerivative, AxisSourceDerivative pots = [] for tgt_knl, src_knl in [(AxisTargetDerivative(0, knl), knl), (knl, AxisSourceDerivative(0, knl))]: - wcc = SumpyExpansionWranglerCodeContainer( + tree_indep = SumpyTreeIndependentDataForWrangler( ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), @@ -497,14 +494,14 @@ def test_sumpy_axis_source_derivative(ctx_factory): source_kernels=[src_knl], exclude_self=True) - wrangler = wcc.get_wrangler(queue, tree, dtype, + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, self_extra_kwargs=self_extra_kwargs, translation_classes_data=SumpyTranslationClassesData(queue, trav)) from boxtree.fmm import drive_fmm - pot, = drive_fmm(trav, wrangler, (weights,)) + pot, = drive_fmm(wrangler, (weights,)) pots.append(pot.get()) rel_err = la.norm(pots[0] + pots[1]) / la.norm(pots[0]) @@ -552,7 +549,6 @@ def test_sumpy_target_point_multiplier(ctx_factory, deriv_axes): from functools import partial - from sumpy.fmm import SumpyExpansionWranglerCodeContainer from sumpy.kernel import TargetPointMultiplier, AxisTargetDerivative tgt_knls = [TargetPointMultiplier(0, knl), knl, knl] @@ -560,7 +556,7 @@ def test_sumpy_target_point_multiplier(ctx_factory, deriv_axes): tgt_knls[0] = AxisTargetDerivative(axis, tgt_knls[0]) tgt_knls[1] = AxisTargetDerivative(axis, tgt_knls[1]) - wcc = SumpyExpansionWranglerCodeContainer( + tree_indep = SumpyTreeIndependentDataForWrangler( ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), @@ -568,14 +564,14 @@ def test_sumpy_target_point_multiplier(ctx_factory, deriv_axes): source_kernels=[knl], exclude_self=True) - wrangler = wcc.get_wrangler(queue, tree, dtype, + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, self_extra_kwargs=self_extra_kwargs, translation_classes_data=SumpyTranslationClassesData(queue, trav)) from boxtree.fmm import drive_fmm - pot0, pot1, pot2 = drive_fmm(trav, wrangler, (weights,)) + pot0, pot1, pot2 = drive_fmm(wrangler, (weights,)) pot0, pot1, pot2 = pot0.get(), pot1.get(), pot2.get() if deriv_axes == (0,): ref_pot = pot1 * sources[0].get() + pot2 -- GitLab