From 189c13d2de3027740cab4f6ff9830623797e85f3 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sun, 25 Jun 2017 21:09:05 -0500 Subject: [PATCH 1/7] Start working on an FMM infrastructure for UnregularizedLayerPotentialSource. --- pytential/unregularized.py | 172 ++++++++++++++++++++++++++++++++++++- 1 file changed, 171 insertions(+), 1 deletion(-) diff --git a/pytential/unregularized.py b/pytential/unregularized.py index 4ea94bf1..c0256aeb 100644 --- a/pytential/unregularized.py +++ b/pytential/unregularized.py @@ -25,9 +25,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + import six +import numpy as np +import loopy as lp + +from boxtre.tools import DeviceDataRecord from pytential.source import LayerPotentialSourceBase +from pytools import memoize_method + +import pyopencl as cl +import pyopencl.array # noqa import logging logger = logging.getLogger(__name__) @@ -46,14 +55,30 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): """ def __init__(self, density_discr, + fmm_order=None, + fmm_level_to_order=None, # begin undocumented arguments # FIXME default debug=False once everything works debug=True): """ + :arg fmm_order: `False` for direct calculation. """ self.density_discr = density_discr self.debug = debug + if fmm_order is not None and fmm_level_to_order is not None: + raise TypeError("may not specify both fmm_order and fmm_level_to_order") + + if fmm_level_to_order is None: + if fmm_order is not None: + def fmm_level_to_order(level): + return fmm_order + + self.density_discr = density_discr + self.fmm_level_to_order = fmm_level_to_order + + self.debug = debug + @property def fine_density_discr(self): return self.density_discr @@ -67,9 +92,12 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): def copy( self, density_discr=None, - debug=None + fmm_level_to_order=None, + debug=None, ): return type(self)( + fmm_level_to_order=( + fmm_level_to_order or self.fmm_level_to_order), density_discr=density_discr or self.density_discr, debug=debug if debug is not None else self.debug) @@ -129,6 +157,148 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): # }}} +# {{{ fmm tools + +class _FMMGeometryCodeContainer(object): + + def __init__(self, cl_context, ambient_dim, debug): + self.cl_context = cl_context + self.ambient_dim = ambient_dim + self.debug = debug + + @memoize_method + def copy_targets_kernel(self): + knl = lp.make_kernel( + """{[dim,i]: + 0<=dim Date: Sun, 25 Jun 2017 21:10:41 -0500 Subject: [PATCH 2/7] Whitespace fix. --- pytential/unregularized.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytential/unregularized.py b/pytential/unregularized.py index c0256aeb..77cac188 100644 --- a/pytential/unregularized.py +++ b/pytential/unregularized.py @@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import six import numpy as np -- GitLab From e78cc1eb9bc470aaf3c5a69283e5dc82287d12fb Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sun, 25 Jun 2017 21:38:07 -0500 Subject: [PATCH 3/7] Fix typo. --- pytential/unregularized.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytential/unregularized.py b/pytential/unregularized.py index 77cac188..4d7f2f92 100644 --- a/pytential/unregularized.py +++ b/pytential/unregularized.py @@ -30,7 +30,7 @@ import six import numpy as np import loopy as lp -from boxtre.tools import DeviceDataRecord +from boxtree.tools import DeviceDataRecord from pytential.source import LayerPotentialSourceBase from pytools import memoize_method -- GitLab From 828860c8f83beeb16a559dc10a0b41bdda8d9d63 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Mon, 26 Jun 2017 23:58:46 -0500 Subject: [PATCH 4/7] Move some FMM preprocessing code to LayerPotentialSourceBase. --- pytential/qbx/__init__.py | 57 +++++---------------------------------- pytential/source.py | 54 +++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 51 deletions(-) diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index 0b94f90b..0f751e7e 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -514,56 +514,13 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): strengths = (evaluate(insn.density).with_queue(queue) * self.weights_and_area_elements()) - # {{{ get expansion wrangler - - base_kernel = None - out_kernels = [] - - from sumpy.kernel import AxisTargetDerivativeRemover - for knl in insn.kernels: - candidate_base_kernel = AxisTargetDerivativeRemover()(knl) - - if base_kernel is None: - base_kernel = candidate_base_kernel - else: - assert base_kernel == candidate_base_kernel - out_kernels = tuple(knl for knl in insn.kernels) - - if base_kernel.is_complex_valued or strengths.dtype.kind == "c": - value_dtype = self.complex_dtype - else: - value_dtype = self.real_dtype - - # {{{ build extra_kwargs dictionaries - - # This contains things like the Helmholtz parameter k or - # the normal directions for double layers. - - def reorder_sources(source_array): - if isinstance(source_array, cl.array.Array): - return (source_array - .with_queue(queue) - [geo_data.tree().user_source_ids] - .with_queue(None)) - else: - return source_array - - kernel_extra_kwargs = {} - source_extra_kwargs = {} - - from sumpy.tools import gather_arguments, gather_source_arguments - from pytools.obj_array import with_object_array_or_scalar - for func, var_dict in [ - (gather_arguments, kernel_extra_kwargs), - (gather_source_arguments, source_extra_kwargs), - ]: - for arg in func(out_kernels): - var_dict[arg.name] = with_object_array_or_scalar( - reorder_sources, - evaluate(insn.kernel_arguments[arg.name])) - - # }}} + base_kernel = self.get_fmm_base_kernel(out_kernels) + value_dtype = self.get_fmm_value_dtype(base_kernel, strengths) + kernel_extra_kwargs, source_extra_kwargs = ( + self.get_fmm_expansion_wrangler_extra_kwargs( + queue, out_kernels, geo_data.tree().user_source_ids, + insn.kernel_arguments, evaluate)) wrangler = self.expansion_wrangler_code_container( base_kernel, out_kernels).get_wrangler( @@ -573,8 +530,6 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): source_extra_kwargs=source_extra_kwargs, kernel_extra_kwargs=kernel_extra_kwargs) - # }}} - if len(geo_data.global_qbx_centers()) != geo_data.ncenters: raise NotImplementedError("geometry has centers requiring local QBX") diff --git a/pytential/source.py b/pytential/source.py index 6d78da57..2d2adeb5 100644 --- a/pytential/source.py +++ b/pytential/source.py @@ -206,6 +206,60 @@ class LayerPotentialSourceBase(PotentialSource): return p2p + # {{{ fmm setup helpers + + def get_fmm_base_kernel(self, kernels): + base_kernel = None + + from sumpy.kernel import AxisTargetDerivativeRemover + for knl in kernels: + candidate_base_kernel = AxisTargetDerivativeRemover()(knl) + + if base_kernel is None: + base_kernel = candidate_base_kernel + else: + assert base_kernel == candidate_base_kernel + + return base_kernel + + def get_fmm_value_dtype(self, base_kernel, strengths): + if base_kernel.is_complex_valued or strengths.dtype.kind == "c": + return self.complex_dtype + else: + return self.real_dtype + + def get_fmm_expansion_wrangler_extra_kwargs( + self, queue, out_kernels, tree_user_source_ids, arguments, evaluator): + # This contains things like the Helmholtz parameter k or + # the normal directions for double layers. + + def reorder_sources(source_array): + if isinstance(source_array, cl.array.Array): + return (source_array + .with_queue(queue) + [tree_user_source_ids] + .with_queue(None)) + else: + return source_array + + kernel_extra_kwargs = {} + source_extra_kwargs = {} + + from sumpy.tools import gather_arguments, gather_source_arguments + from pytools.obj_array import with_object_array_or_scalar + for func, var_dict in [ + (gather_arguments, kernel_extra_kwargs), + (gather_source_arguments, source_extra_kwargs), + ]: + for arg in func(out_kernels): + var_dict[arg.name] = with_object_array_or_scalar( + reorder_sources, + evaluator(arguments[arg.name])) + + return kernel_extra_kwargs, source_extra_kwargs + + # }}} + # {{{ weights and area elements @memoize_method -- GitLab From c660125d2d7716f5a687636c148207f899923d83 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Mon, 26 Jun 2017 23:59:05 -0500 Subject: [PATCH 5/7] Finish implementing FMM based execution in UnregularizedLayerPotentialSource. --- pytential/unregularized.py | 117 +++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/pytential/unregularized.py b/pytential/unregularized.py index 4d7f2f92..c28833e1 100644 --- a/pytential/unregularized.py +++ b/pytential/unregularized.py @@ -51,11 +51,14 @@ __doc__ = """ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): """A source discretization for a layer potential discretized with a Nyström method that uses panel-based quadrature and does not modify the kernel. + + .. attribute:: fmm_level_to_order """ def __init__(self, density_discr, fmm_order=None, fmm_level_to_order=None, + expansion_factory=None, # begin undocumented arguments # FIXME default debug=False once everything works debug=True): @@ -72,10 +75,17 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): if fmm_order is not None: def fmm_level_to_order(level): return fmm_order + else: + fmm_level_to_order = False self.density_discr = density_discr self.fmm_level_to_order = fmm_level_to_order + if expansion_factory is None: + from sumpy.expansion import DefaultExpansionFactory + expansion_factory = DefaultExpansionFactory() + self.expansion_factory = expansion_factory + self.debug = debug @property @@ -107,6 +117,11 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): value = evaluate(expr) return with_object_array_or_scalar(lambda x: x, value) + if self.fmm_level_to_order is False: + func = self.exec_compute_potential_insn_direct + else: + func = self.exec_compute_potential_insn_fmm + func = self.exec_compute_potential_insn_direct return func(queue, insn, bound_expr, evaluate_wrapper) @@ -153,6 +168,108 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): return result, [] + # {{{ fmm-based execution + + @memoize_method + def expansion_wrangler_code_container(self, base_kernel, out_kernels): + mpole_expn_class = \ + self.expansion_factory.get_multipole_expansion_class(base_kernel) + local_expn_class = \ + self.expansion_factory.get_local_expansion_class(base_kernel) + + from functools import partial + fmm_mpole_factory = partial(mpole_expn_class, base_kernel) + fmm_local_factory = partial(local_expn_class, base_kernel) + + from sumpy.fmm import SumpyExpansionWrangerCodeContainer + return SumpyExpansionWrangerCodeContainer( + self.cl_context, + fmm_mpole_factory, + fmm_local_factory, + out_kernels) + + @property + @memoize_method + def fmm_geometry_code_container(self): + return _FMMGeometryCodeContainer( + self.cl_context, self.ambient_dim, self.debug) + + def fmm_geometry_data(self, targets): + return _FMMGeometryData( + self, + self.fmm_geometry_code_container, + targets, + self.debug) + + def exec_compute_potential_insn_fmm(self, queue, insn, bound_expr, evaluate): + + # {{{ gather unique target discretizations used + + target_name_to_index = {} + targets = [] + + for o in insn.outputs: + assert o.qbx_forced_limit not in (-1, 1) + + if o.target_name in target_name_to_index: + continue + + target_name_to_index[o.target_name] = len(targets) + targets.append(bound_expr.places[o.target_name]) + + targets = tuple(targets) + + # }}} + + # {{{ get wrangler + + geo_data = self.fmm_geometry_data(targets) + + strengths = (evaluate(insn.density).with_queue(queue) + * self.weights_and_area_elements()) + + out_kernels = tuple(knl for knl in insn.kernels) + base_kernel = self.get_fmm_base_kernel(out_kernels) + value_dtype = self.get_fmm_value_dtype(base_kernel, strengths) + kernel_extra_kwargs, source_extra_kwargs = ( + self.get_fmm_expansion_wrangler_extra_kwargs( + queue, out_kernels, geo_data.tree().user_source_ids, + insn.kernel_arguments, evaluate)) + + wrangler = self.expansion_wrangler_code_container( + out_kernels, base_kernel).get_wrangler( + queue, + geo_data.tree(), + value_dtype, + self.fmm_level_to_order, + source_extra_kwargs=source_extra_kwargs, + kernel_extra_kwargs=kernel_extra_kwargs) + + # }}} + + from boxtree.fmm import drive_fmm + all_potentials_on_every_tgt = drive_fmm( + geo_data.traversal(), wrangler, strengths) + + # {{{ postprocess fmm + + result = [] + + for o in insn.outputs: + target_index = target_name_to_index[o.target_name] + target_slice = slice(*geo_data.target_info().target_discr_starts[ + target_index:target_index+2]) + + result.append( + (o.name, + all_potentials_on_every_tgt[o.kernel_index][target_slice])) + + # }}} + + return result, [] + + # }}} + # }}} -- GitLab From 29ff7b30bfdfd53eb5138ffa03dfab4a18a89b03 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 27 Jun 2017 00:40:07 -0500 Subject: [PATCH 6/7] Get the FMM working; add an off surface eval test. --- pytential/unregularized.py | 21 ++++++++++++------ test/test_layer_pot.py | 44 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/pytential/unregularized.py b/pytential/unregularized.py index c28833e1..9dee265c 100644 --- a/pytential/unregularized.py +++ b/pytential/unregularized.py @@ -56,7 +56,7 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): """ def __init__(self, density_discr, - fmm_order=None, + fmm_order=False, fmm_level_to_order=None, expansion_factory=None, # begin undocumented arguments @@ -68,11 +68,11 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): self.density_discr = density_discr self.debug = debug - if fmm_order is not None and fmm_level_to_order is not None: + if fmm_order is not False and fmm_level_to_order is not None: raise TypeError("may not specify both fmm_order and fmm_level_to_order") if fmm_level_to_order is None: - if fmm_order is not None: + if fmm_order is not False: def fmm_level_to_order(level): return fmm_order else: @@ -122,7 +122,6 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): else: func = self.exec_compute_potential_insn_fmm - func = self.exec_compute_potential_insn_direct return func(queue, insn, bound_expr, evaluate_wrapper) def op_group_features(self, expr): @@ -181,8 +180,8 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): fmm_mpole_factory = partial(mpole_expn_class, base_kernel) fmm_local_factory = partial(local_expn_class, base_kernel) - from sumpy.fmm import SumpyExpansionWrangerCodeContainer - return SumpyExpansionWrangerCodeContainer( + from sumpy.fmm import SumpyExpansionWranglerCodeContainer + return SumpyExpansionWranglerCodeContainer( self.cl_context, fmm_mpole_factory, fmm_local_factory, @@ -237,7 +236,7 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): insn.kernel_arguments, evaluate)) wrangler = self.expansion_wrangler_code_container( - out_kernels, base_kernel).get_wrangler( + base_kernel, out_kernels).get_wrangler( queue, geo_data.tree(), value_dtype, @@ -317,7 +316,13 @@ class _FMMGeometryCodeContainer(object): class _TargetInfo(DeviceDataRecord): """ .. attribute:: targets + + Shape: ``[dim,ntargets]`` + .. attribute:: target_discr_starts + + Shape: ``[ndiscrs+1]`` + .. attribute:: ntargets """ @@ -396,6 +401,8 @@ class _FMMGeometryData(object): target_discr_starts.append(ntargets) ntargets += target_discr.nnodes + target_discr_starts.append(ntargets) + targets = cl.array.empty( self.cl_context, (lpot_src.ambient_dim, ntargets), diff --git a/test/test_layer_pot.py b/test/test_layer_pot.py index 34d11d3f..6dc8e5a9 100644 --- a/test/test_layer_pot.py +++ b/test/test_layer_pot.py @@ -1321,6 +1321,50 @@ def test_unregularized_with_ones_kernel(ctx_getter): assert np.allclose(result_self.get(), 2 * np.pi) assert np.allclose(result_nonself.get(), 2 * np.pi) + +def test_unregularized_off_surface_fmm_vs_direct(ctx_getter): + cl_ctx = ctx_getter() + queue = cl.CommandQueue(cl_ctx) + + nelements = 300 + target_order = 8 + fmm_order = 4 + + mesh = make_curve_mesh(WobblyCircle.random(8, seed=30), + np.linspace(0, 1, nelements+1), + target_order) + + from pytential.unregularized import UnregularizedLayerPotentialSource + from meshmode.discretization import Discretization + from meshmode.discretization.poly_element import \ + InterpolatoryQuadratureSimplexGroupFactory + + density_discr = Discretization( + cl_ctx, mesh, InterpolatoryQuadratureSimplexGroupFactory(target_order)) + direct = UnregularizedLayerPotentialSource( + density_discr, + fmm_order=False, + ) + fmm = direct.copy(fmm_level_to_order=lambda _: fmm_order) + + sigma = density_discr.zeros(queue) + 1 + + fplot = FieldPlotter(np.zeros(2), extent=5, npoints=100) + from pytential.target import PointsTarget + ptarget = PointsTarget(fplot.points) + from sumpy.kernel import LaplaceKernel + + op = sym.D(LaplaceKernel(2), sym.var("sigma"), qbx_forced_limit=None) + + direct_fld_in_vol = bind((direct, ptarget), op)(queue, sigma=sigma) + fmm_fld_in_vol = bind((fmm, ptarget), op)(queue, sigma=sigma) + + err = cl.clmath.fabs(fmm_fld_in_vol - direct_fld_in_vol) + + linf_err = cl.array.max(err).get() + print("l_inf error:", linf_err) + assert linf_err < 5e-3 + # }}} -- GitLab From 75c6a79efceaae831fe618f35bccdceeda741cf7 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 6 Jul 2017 01:12:00 -0500 Subject: [PATCH 7/7] Change some terminology: * base_kernel -> fmm_kernel * value_dtype -> output_and_expansion_dtype --- pytential/qbx/__init__.py | 21 +++++++++++---------- pytential/source.py | 16 ++++++++-------- pytential/unregularized.py | 19 ++++++++++--------- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index ee5be757..50fb8e58 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -441,16 +441,16 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): # {{{ fmm-based execution @memoize_method - def expansion_wrangler_code_container(self, base_kernel, out_kernels): + def expansion_wrangler_code_container(self, fmm_kernel, out_kernels): mpole_expn_class = \ - self.expansion_factory.get_multipole_expansion_class(base_kernel) + self.expansion_factory.get_multipole_expansion_class(fmm_kernel) local_expn_class = \ - self.expansion_factory.get_local_expansion_class(base_kernel) + self.expansion_factory.get_local_expansion_class(fmm_kernel) from functools import partial - fmm_mpole_factory = partial(mpole_expn_class, base_kernel) - fmm_local_factory = partial(local_expn_class, base_kernel) - qbx_local_factory = partial(local_expn_class, base_kernel) + fmm_mpole_factory = partial(mpole_expn_class, fmm_kernel) + fmm_local_factory = partial(local_expn_class, fmm_kernel) + qbx_local_factory = partial(local_expn_class, fmm_kernel) if self.fmm_backend == "sumpy": from pytential.qbx.fmm import \ @@ -515,16 +515,17 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): * self.weights_and_area_elements()) out_kernels = tuple(knl for knl in insn.kernels) - base_kernel = self.get_fmm_base_kernel(out_kernels) - value_dtype = self.get_fmm_value_dtype(base_kernel, strengths) + fmm_kernel = self.get_fmm_kernel(out_kernels) + output_and_expansion_dtype = ( + self.get_fmm_output_and_expansion_dtype(fmm_kernel, strengths)) kernel_extra_kwargs, source_extra_kwargs = ( self.get_fmm_expansion_wrangler_extra_kwargs( queue, out_kernels, geo_data.tree().user_source_ids, insn.kernel_arguments, evaluate)) wrangler = self.expansion_wrangler_code_container( - base_kernel, out_kernels).get_wrangler( - queue, geo_data, value_dtype, + fmm_kernel, out_kernels).get_wrangler( + queue, geo_data, output_and_expansion_dtype, self.qbx_order, self.fmm_level_to_order, source_extra_kwargs=source_extra_kwargs, diff --git a/pytential/source.py b/pytential/source.py index 2d2adeb5..75d0db70 100644 --- a/pytential/source.py +++ b/pytential/source.py @@ -208,21 +208,21 @@ class LayerPotentialSourceBase(PotentialSource): # {{{ fmm setup helpers - def get_fmm_base_kernel(self, kernels): - base_kernel = None + def get_fmm_kernel(self, kernels): + fmm_kernel = None from sumpy.kernel import AxisTargetDerivativeRemover for knl in kernels: - candidate_base_kernel = AxisTargetDerivativeRemover()(knl) + candidate_fmm_kernel = AxisTargetDerivativeRemover()(knl) - if base_kernel is None: - base_kernel = candidate_base_kernel + if fmm_kernel is None: + fmm_kernel = candidate_fmm_kernel else: - assert base_kernel == candidate_base_kernel + assert fmm_kernel == candidate_fmm_kernel - return base_kernel + return fmm_kernel - def get_fmm_value_dtype(self, base_kernel, strengths): + def get_fmm_output_and_expansion_dtype(self, base_kernel, strengths): if base_kernel.is_complex_valued or strengths.dtype.kind == "c": return self.complex_dtype else: diff --git a/pytential/unregularized.py b/pytential/unregularized.py index 9dee265c..fcd5298e 100644 --- a/pytential/unregularized.py +++ b/pytential/unregularized.py @@ -170,15 +170,15 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): # {{{ fmm-based execution @memoize_method - def expansion_wrangler_code_container(self, base_kernel, out_kernels): + def expansion_wrangler_code_container(self, fmm_kernel, out_kernels): mpole_expn_class = \ - self.expansion_factory.get_multipole_expansion_class(base_kernel) + self.expansion_factory.get_multipole_expansion_class(fmm_kernel) local_expn_class = \ - self.expansion_factory.get_local_expansion_class(base_kernel) + self.expansion_factory.get_local_expansion_class(fmm_kernel) from functools import partial - fmm_mpole_factory = partial(mpole_expn_class, base_kernel) - fmm_local_factory = partial(local_expn_class, base_kernel) + fmm_mpole_factory = partial(mpole_expn_class, fmm_kernel) + fmm_local_factory = partial(local_expn_class, fmm_kernel) from sumpy.fmm import SumpyExpansionWranglerCodeContainer return SumpyExpansionWranglerCodeContainer( @@ -228,18 +228,19 @@ class UnregularizedLayerPotentialSource(LayerPotentialSourceBase): * self.weights_and_area_elements()) out_kernels = tuple(knl for knl in insn.kernels) - base_kernel = self.get_fmm_base_kernel(out_kernels) - value_dtype = self.get_fmm_value_dtype(base_kernel, strengths) + fmm_kernel = self.get_fmm_kernel(out_kernels) + output_and_expansion_dtype = ( + self.get_fmm_output_and_expansion_dtype(fmm_kernel, strengths)) kernel_extra_kwargs, source_extra_kwargs = ( self.get_fmm_expansion_wrangler_extra_kwargs( queue, out_kernels, geo_data.tree().user_source_ids, insn.kernel_arguments, evaluate)) wrangler = self.expansion_wrangler_code_container( - base_kernel, out_kernels).get_wrangler( + fmm_kernel, out_kernels).get_wrangler( queue, geo_data.tree(), - value_dtype, + output_and_expansion_dtype, self.fmm_level_to_order, source_extra_kwargs=source_extra_kwargs, kernel_extra_kwargs=kernel_extra_kwargs) -- GitLab