From 48bdc5ef3fd17272b5824221d17337ec30780b9d Mon Sep 17 00:00:00 2001
From: Isuru Fernando <idf2@illinois.edu>
Date: Sun, 27 Mar 2022 22:28:35 -0700
Subject: [PATCH] Direct loopy kernel (#95)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* post process kernel

* Fix postprocesslocal

* use preprocess/postprocess everywhere

* Fix event management for m2l

* Use a loopy kernel

* return a loopy kernel

* fix formatting

* reduce diff

* Fix typo

* Fix typo

* use_preprocessing_for_m2l -> use_fft_for_m2l

* supports_optimized_m2l -> supports_translation_classes

* restore non fft optimized code path

* events -> timing_events

* More descriptive exception message

* use lp.Assignment instead of strings

* Fix typo

* Add a comment about symbolic sum

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* Fix error message

* split too long line

* Make use_preprocessing_for_m2l an option

* Add an comment about the options

* fix syntax errors

* add missing defines

* Fix hashing

* Fix returning nexprs

* fix use_preprocessing_for_m2l

* Fix ncoeff_src and ncoeff_tgt

* fix use_preprocessing_for_m2l for bessel based

* Fix logic

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 sumpy/__init__.py        |   5 +-
 sumpy/e2e.py             | 330 +++++++++++++++++++++++----------------
 sumpy/expansion/local.py | 237 ++++++++++++++++++++--------
 sumpy/fmm.py             | 112 +++++++++----
 test/test_fmm.py         |   2 +-
 5 files changed, 458 insertions(+), 228 deletions(-)

diff --git a/sumpy/__init__.py b/sumpy/__init__.py
index 30b0cd66..b39ead2e 100644
--- a/sumpy/__init__.py
+++ b/sumpy/__init__.py
@@ -26,7 +26,8 @@ from sumpy.p2e import P2EFromSingleBox, P2EFromCSR
 from sumpy.e2p import E2PFromSingleBox, E2PFromCSR
 from sumpy.e2e import (E2EFromCSR, E2EFromChildren, E2EFromParent,
     M2LUsingTranslationClassesDependentData,
-    M2LGenerateTranslationClassesDependentData, M2LPreprocessMultipole)
+    M2LGenerateTranslationClassesDependentData, M2LPreprocessMultipole,
+    M2LPostprocessLocal)
 from sumpy.version import VERSION_TEXT
 from pytools.persistent_dict import WriteOncePersistentDict
 
@@ -37,7 +38,7 @@ __all__ = [
     "E2EFromCSR", "E2EFromChildren", "E2EFromParent",
     "M2LUsingTranslationClassesDependentData",
     "M2LGenerateTranslationClassesDependentData",
-    "M2LPreprocessMultipole"]
+    "M2LPreprocessMultipole", "M2LPostprocessLocal"]
 
 
 code_cache = WriteOncePersistentDict("sumpy-code-cache-v6-"+VERSION_TEXT)
diff --git a/sumpy/e2e.py b/sumpy/e2e.py
index 7774b62a..550cb173 100644
--- a/sumpy/e2e.py
+++ b/sumpy/e2e.py
@@ -23,6 +23,7 @@ THE SOFTWARE.
 import numpy as np
 import loopy as lp
 import sumpy.symbolic as sym
+import pymbolic
 
 from loopy.version import MOST_RECENT_LANGUAGE_VERSION
 from sumpy.tools import KernelCacheWrapper, to_complex_dtype
@@ -84,8 +85,6 @@ class E2EBase(KernelCacheWrapper):
         self.tgt_expansion = tgt_expansion
         self.name = name or self.default_name
         self.device = device
-        self.use_preprocessing_for_m2l = getattr(self.tgt_expansion,
-            "use_preprocessing_for_m2l", False)
 
         if src_expansion.dim != tgt_expansion.dim:
             raise ValueError("source and target expansions must have "
@@ -272,13 +271,6 @@ class E2EFromCSR(E2EBase):
 
         return knl
 
-    def get_cache_key(self):
-        return (
-                type(self).__name__,
-                self.src_expansion,
-                self.tgt_expansion,
-        )
-
     def __call__(self, queue, **kwargs):
         """
         :arg src_expansions:
@@ -320,22 +312,18 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                 self.tgt_expansion.m2l_translation_classes_dependent_ndata(
                         self.src_expansion)
         m2l_translation_classes_dependent_data = \
-                    [sym.Symbol("m2l_translation_classes_dependent_expr%d" % i)
+                    [sym.Symbol("data%d" % i)
                 for i in range(m2l_translation_classes_dependent_ndata)]
 
-        if self.use_preprocessing_for_m2l:
-            ncoeff_src = self.tgt_expansion.m2l_preprocess_multipole_nexprs(
-                self.src_expansion)
-        else:
-            ncoeff_src = len(self.src_expansion)
+        ncoeff_src = len(self.src_expansion)
 
-        src_coeff_exprs = [sym.Symbol("src_coeff%d" % i)
+        src_coeff_exprs = [sym.Symbol("src_coeffs%d" % i)
                 for i in range(ncoeff_src)]
 
         from sumpy.assignment_collection import SymbolicAssignmentCollection
         sac = SymbolicAssignmentCollection()
         tgt_coeff_names = [
-                sac.assign_unique("coeff%d" % i, coeff_i)
+                sac.assign_unique("tgt_coeff%d" % i, coeff_i)
                 for i, coeff_i in enumerate(
                     self.tgt_expansion.translate_from(
                         self.src_expansion, src_coeff_exprs, src_rscale,
@@ -348,87 +336,65 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
         from sumpy.codegen import to_loopy_insns
         return to_loopy_insns(
                 sac.assignments.items(),
-                vector_names=set(["d"]),
+                vector_names=set(["d", "src_coeffs", "data"]),
                 pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()],
                 retain_names=tgt_coeff_names,
                 complex_dtype=to_complex_dtype(result_dtype),
                 )
 
-    def get_postprocess_loopy_insns(self, result_dtype):
-        """Loopy instructions that happen only once for each target box.
-
-        :arg result_dtype: A numpy dtype for the result. This is important
-            because depending on the input to the M2L being real or not
-            the code needs to be different. If the input was real, the
-            result should be the real part of the complex values from the
-            inverse FFT. In the input was complex, the result matches the
-            values from the inverse FFT.
-        """
+    def get_inner_loopy_kernel(self, result_dtype):
+        try:
+            return self.tgt_expansion.loopy_translate_from(
+                self.src_expansion)
+        except NotImplementedError:
+            pass
 
-        ncoeff_tgt = len(self.tgt_expansion)
-        m2l_translation_classes_dependent_ndata = \
-                self.tgt_expansion.m2l_translation_classes_dependent_ndata(
+        ndata = self.tgt_expansion.m2l_translation_classes_dependent_ndata(
                         self.src_expansion)
-
-        if self.use_preprocessing_for_m2l:
-            ncoeff_tgt = m2l_translation_classes_dependent_ndata
-
-        from sumpy.assignment_collection import SymbolicAssignmentCollection
-        sac = SymbolicAssignmentCollection()
-
-        tgt_coeff_exprs = [
-            sym.Symbol("coeff_sum%d" % i) for i in range(ncoeff_tgt)
-        ]
-
-        src_rscale = sym.Symbol("src_rscale")
-        tgt_rscale = sym.Symbol("tgt_rscale")
-
-        if self.use_preprocessing_for_m2l:
-            tgt_coeff_post_exprs = self.tgt_expansion.m2l_postprocess_local_exprs(
-                self.src_expansion, tgt_coeff_exprs, src_rscale, tgt_rscale,
-                sac=sac)
-            if result_dtype in (np.float32, np.float64):
-                real_func = sym.Function("real")
-                tgt_coeff_post_exprs = [real_func(expr) for expr in
-                    tgt_coeff_post_exprs]
+        if self.tgt_expansion.use_preprocessing_for_m2l:
+            ncoeff_src = self.tgt_expansion.m2l_preprocess_multipole_nexprs(
+                    self.src_expansion)
+            ncoeff_tgt = self.tgt_expansion.m2l_postprocess_local_nexprs(
+                    self.src_expansion)
         else:
-            tgt_coeff_post_exprs = tgt_coeff_exprs
-
-        tgt_coeff_post_names = [
-            sac.assign_unique("coeff_post%d" % i, coeff)
-            for i, coeff in enumerate(tgt_coeff_post_exprs)
-        ]
-
-        sac.run_global_cse()
-
-        from sumpy.codegen import to_loopy_insns
-        insns = to_loopy_insns(
-                sac.assignments.items(),
-                vector_names=set(["d"]),
-                pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()],
-                retain_names=tgt_coeff_post_names,
-                complex_dtype=to_complex_dtype(result_dtype),
-                )
-        return insns
+            ncoeff_src = len(self.src_expansion)
+            ncoeff_tgt = len(self.tgt_expansion)
+
+        domains = []
+        insns = self.get_translation_loopy_insns(result_dtype)
+        coeff = pymbolic.var("coeff")
+        for i in range(ncoeff_tgt):
+            expr = pymbolic.var(f"tgt_coeff{i}")
+            insn = lp.Assignment(assignee=coeff[i],
+                    expression=coeff[i] + expr)
+            insns.append(insn)
+
+        return lp.make_function(domains, insns,
+                        kernel_data=[
+                            lp.GlobalArg("coeff", shape=(ncoeff_tgt,),
+                                is_output=True, is_input=True),
+                            lp.GlobalArg("src_coeffs", shape=(ncoeff_src,)),
+                            lp.GlobalArg("data", shape=(ndata,)),
+                            lp.ValueArg("src_rscale"),
+                            lp.ValueArg("tgt_rscale"),
+                            ...],
+                        name="e2e",
+                        lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+                        )
 
     def get_kernel(self, result_dtype):
         m2l_translation_classes_dependent_ndata = \
                 self.tgt_expansion.m2l_translation_classes_dependent_ndata(
                         self.src_expansion)
 
-        if self.use_preprocessing_for_m2l:
-            # number of expressions given as input to M2L after preprocessing
+        if self.tgt_expansion.use_preprocessing_for_m2l:
             ncoeff_src = self.tgt_expansion.m2l_preprocess_multipole_nexprs(
                     self.src_expansion)
-            # number of expressions given as input to postprocessing
-            ncoeff_tgt_before_postprocess = \
-                    self.tgt_expansion.m2l_postprocess_local_nexprs(
-                        self.src_expansion)
+            ncoeff_tgt = self.tgt_expansion.m2l_postprocess_local_nexprs(
+                    self.src_expansion)
         else:
             ncoeff_src = len(self.src_expansion)
-            ncoeff_tgt_before_postprocess = len(self.tgt_expansion)
-
-        ncoeff_tgt = len(self.tgt_expansion)
+            ncoeff_tgt = len(self.tgt_expansion)
 
         # To clarify terminology:
         #
@@ -437,53 +403,44 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
         #
         # (same for itgt_box, tgt_ibox)
 
+        translation_knl = self.get_inner_loopy_kernel(result_dtype)
+
         from sumpy.tools import gather_loopy_arguments
         loopy_knl = lp.make_kernel(
                 [
                     "{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
                     "{[isrc_box]: isrc_start<=isrc_box<isrc_stop}",
-                    "{[idim]: 0<=idim<dim}",
-                    ],
+                    "{[icoeff_tgt]: 0<=icoeff_tgt<ncoeff_tgt}",
+                    "{[icoeff_src]: 0<=icoeff_src<ncoeff_src}",
+                    "{[idep]: 0<=idep<m2l_translation_classes_dependent_ndata}",
+                ],
                 ["""
                 for itgt_box
                     <> tgt_ibox = target_boxes[itgt_box]
-                    <> tgt_center[idim] = centers[idim, tgt_ibox]
                     <> isrc_start = src_box_starts[itgt_box]
                     <> isrc_stop = src_box_starts[itgt_box+1]
-
+                    for icoeff_tgt
+                        <> coeffs[icoeff_tgt] = 0 {id=init_coeffs, dup=icoeff_tgt}
+                    end
                     for isrc_box
                         <> src_ibox = src_box_lists[isrc_box] \
                                 {id=read_src_ibox}
-                        <> src_center[idim] = centers[idim, src_ibox] {dup=idim}
-                        <> d[idim] = tgt_center[idim] - src_center[idim] \
-                            {dup=idim}
                         <> translation_class = \
                                 m2l_translation_classes_lists[isrc_box]
                         <> translation_class_rel = translation_class - \
                                                     translation_classes_level_start
-                        """] + ["""
-                        <> m2l_translation_classes_dependent_expr{idx} = \
-                            m2l_translation_classes_dependent_data[ \
-                                translation_class_rel, {idx}]
-                        """.format(idx=idx) for idx in range(
-                            m2l_translation_classes_dependent_ndata)] + ["""
-                        <> src_coeff{coeffidx} = \
-                            src_expansions[src_ibox - src_base_ibox, {coeffidx}] \
-                            {{dep=read_src_ibox}}
-                        """.format(coeffidx=i) for i in range(ncoeff_src)] + [
-
-                        ] + self.get_translation_loopy_insns(result_dtype) + ["""
+                        [icoeff_tgt]: coeffs[icoeff_tgt] = e2e(
+                            [icoeff_tgt]: coeffs[icoeff_tgt],
+                            [icoeff_src]: src_expansions[src_ibox - src_base_ibox,
+                                icoeff_src],
+                            [idep]: m2l_translation_classes_dependent_data[
+                                translation_class_rel, idep],
+                            src_rscale,
+                            tgt_rscale,
+                            )  {dep=init_coeffs,id=update_coeffs}
                     end
-
-                    """] + ["""
-                    <> coeff_sum{coeffidx} = \
-                        simul_reduce(sum, isrc_box, coeff{coeffidx})
-                    """.format(coeffidx=i) for i in
-                        range(ncoeff_tgt_before_postprocess)] + [
-                    ] + self.get_postprocess_loopy_insns(result_dtype) + [f"""
-                    tgt_expansions[tgt_ibox - tgt_base_ibox, {coeffidx}] = \
-                            coeff_post{coeffidx} {{id_prefix=write_expn}}
-                    """ for coeffidx in range(ncoeff_tgt)] + ["""
+                    tgt_expansions[tgt_ibox - tgt_base_ibox, icoeff_tgt] = \
+                            coeffs[icoeff_tgt] {dep=update_coeffs, dup=icoeff_tgt}
                 end
                 """],
                 [
@@ -498,7 +455,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                     lp.GlobalArg("src_expansions", None,
                         shape=("nsrc_level_boxes", ncoeff_src), offset=lp.auto),
                     lp.GlobalArg("tgt_expansions", None,
-                        shape=("ntgt_level_boxes", ncoeff_tgt), offset=lp.auto),
+                        shape=("ntgt_level_boxes", ncoeff_tgt),
+                        offset=lp.auto),
                     lp.ValueArg("translation_classes_level_start",
                         np.int32),
                     lp.GlobalArg("m2l_translation_classes_dependent_data", None,
@@ -515,26 +473,26 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                                             self.tgt_expansion]),
                 name=self.name,
                 assumptions="ntgt_boxes>=1",
-                silenced_warnings="write_race(write_expn*)",
                 default_offset=lp.auto,
                 fixed_parameters=dict(dim=self.dim,
                         m2l_translation_classes_dependent_ndata=(
-                            m2l_translation_classes_dependent_ndata)),
+                            m2l_translation_classes_dependent_ndata),
+                        ncoeff_tgt=ncoeff_tgt,
+                        ncoeff_src=ncoeff_src),
                 lang_version=MOST_RECENT_LANGUAGE_VERSION
                 )
 
+        loopy_knl = lp.merge([translation_knl, loopy_knl])
+        loopy_knl = lp.inline_callable_kernel(loopy_knl, "e2e")
+
         for knl in [self.src_expansion.kernel, self.tgt_expansion.kernel]:
             loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
 
-        loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
-        loopy_knl = lp.set_options(loopy_knl,
-                enforce_variable_access_ordered="no_check")
-
         return loopy_knl
 
     def get_optimized_kernel(self, result_dtype):
-        # FIXME
         knl = self.get_kernel(result_dtype)
+        # FIXME
         knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0")
 
         return knl
@@ -553,17 +511,12 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
         # meaningfully inferred. Make the type of rscale explicit.
         src_rscale = centers.dtype.type(kwargs.pop("src_rscale"))
         tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale"))
+        src_expansions = kwargs.pop("src_expansions")
 
-        if "tgt_expansions" in kwargs:
-            tgt_expansions = kwargs["tgt_expansions"]
-            result_dtype = tgt_expansions.dtype
-        else:
-            src_expansions = kwargs["src_expansions"]
-            result_dtype = src_expansions.dtype
-
-        knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
+        knl = self.get_cached_optimized_kernel(result_dtype=src_expansions.dtype)
 
         return knl(queue,
+                src_expansions=src_expansions,
                 centers=centers,
                 src_rscale=src_rscale, tgt_rscale=tgt_rscale,
                 **kwargs)
@@ -788,6 +741,118 @@ class M2LPreprocessMultipole(E2EBase):
 # }}}
 
 
+# {{{ M2LPostprocessLocal
+
+class M2LPostprocessLocal(E2EBase):
+    """Postprocesses locals expansions for accelerated M2L"""
+
+    default_name = "m2l_postprocess_local"
+
+    def get_loopy_insns(self, result_dtype):
+        ncoeffs_before_postprocessing = \
+            self.tgt_expansion.m2l_postprocess_local_nexprs(self.tgt_expansion)
+
+        tgt_coeff_exprs_before_postprocessing = [
+            sym.Symbol("tgt_coeff_before_postprocessing%d" % i)
+            for i in range(ncoeffs_before_postprocessing)]
+
+        src_rscale = sym.Symbol("src_rscale")
+        tgt_rscale = sym.Symbol("tgt_rscale")
+
+        from sumpy.assignment_collection import SymbolicAssignmentCollection
+        sac = SymbolicAssignmentCollection()
+
+        tgt_coeff_exprs = self.tgt_expansion.m2l_postprocess_local_exprs(
+            self.tgt_expansion, tgt_coeff_exprs_before_postprocessing,
+            sac=sac, src_rscale=src_rscale, tgt_rscale=tgt_rscale)
+
+        if result_dtype in (np.float32, np.float64):
+            real_func = sym.Function("real")
+            tgt_coeff_exprs = [real_func(expr) for expr in
+                    tgt_coeff_exprs]
+
+        tgt_coeff_names = [
+            sac.assign_unique("tgt_coeff%d" % i, coeff_i)
+            for i, coeff_i in enumerate(tgt_coeff_exprs)]
+
+        sac.run_global_cse()
+
+        from sumpy.codegen import to_loopy_insns
+        return to_loopy_insns(
+                sac.assignments.items(),
+                vector_names=set(["d"]),
+                pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()],
+                retain_names=tgt_coeff_names,
+                complex_dtype=to_complex_dtype(result_dtype),
+                )
+
+    def get_kernel(self, result_dtype):
+        ntgt_coeffs = len(self.tgt_expansion)
+        ntgt_coeffs_before_postprocessing = \
+            self.tgt_expansion.m2l_postprocess_local_nexprs(self.tgt_expansion)
+        from sumpy.tools import gather_loopy_arguments
+        loopy_knl = lp.make_kernel(
+                [
+                    "{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
+                    ],
+                ["""
+                for itgt_box
+                """] + ["""
+                    <> tgt_coeff_before_postprocessing{idx} = \
+                            tgt_expansions_before_postprocessing[itgt_box, {idx}]
+                """.format(idx=i) for i in range(
+                    ntgt_coeffs_before_postprocessing)]
+                + self.get_loopy_insns(result_dtype) + ["""
+                    tgt_expansions[itgt_box, {idx}] = \
+                        tgt_coeff{idx}
+                    """.format(idx=i) for i in range(ntgt_coeffs)] + ["""
+                end
+                """],
+                [
+                    lp.ValueArg("ntgt_boxes", np.int32),
+                    lp.ValueArg("src_rscale", None),
+                    lp.ValueArg("tgt_rscale", None),
+                    lp.GlobalArg("tgt_expansions", result_dtype,
+                        shape=("ntgt_boxes", ntgt_coeffs), offset=lp.auto),
+                    lp.GlobalArg("tgt_expansions_before_postprocessing", None,
+                        shape=("ntgt_boxes", ntgt_coeffs_before_postprocessing),
+                        offset=lp.auto),
+                    "..."
+                ] + gather_loopy_arguments([self.src_expansion, self.tgt_expansion]),
+                name=self.name,
+                assumptions="ntgt_boxes>=1",
+                default_offset=lp.auto,
+                fixed_parameters=dict(dim=self.dim),
+                lang_version=MOST_RECENT_LANGUAGE_VERSION
+                )
+
+        for expn in [self.src_expansion.kernel, self.tgt_expansion.kernel]:
+            loopy_knl = expn.prepare_loopy_kernel(loopy_knl)
+
+        loopy_knl = lp.set_options(loopy_knl,
+                enforce_variable_access_ordered="no_check")
+        return loopy_knl
+
+    def get_optimized_kernel(self, result_dtype):
+        # FIXME
+        knl = self.get_kernel(result_dtype)
+        knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0")
+        return knl
+
+    def __call__(self, queue, **kwargs):
+        """
+        :arg tgt_expansions
+        :arg tgt_expansions_before_postprocessing
+        """
+        tgt_expansions = kwargs.pop("tgt_expansions")
+        result_dtype = tgt_expansions.dtype
+        knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
+        return knl(queue,
+            tgt_expansions=tgt_expansions, **kwargs)
+
+# }}}
+
+
 # {{{ translation from a box's children
 
 class E2EFromChildren(E2EBase):
@@ -796,10 +861,11 @@ class E2EFromChildren(E2EBase):
     def get_kernel(self):
         if self.src_expansion is not self.tgt_expansion:
             raise RuntimeError("%s requires that the source "
-                    "and target expansion are the same object"
-                    % type(self).__name__)
+                   "and target expansion are the same object"
+                   % type(self).__name__)
 
-        ncoeffs = len(self.src_expansion)
+        ncoeffs_src = len(self.src_expansion)
+        ncoeffs_tgt = len(self.tgt_expansion)
 
         # To clarify terminology:
         #
@@ -841,14 +907,14 @@ class E2EFromChildren(E2EBase):
                             <> src_coeff{i} = \
                                 src_expansions[src_ibox - src_base_ibox, {i}] \
                                 {{id_prefix=read_coeff,dep=read_src_ibox}}
-                            """.format(i=i) for i in range(ncoeffs)] + [
+                            """.format(i=i) for i in range(ncoeffs_src)] + [
                             ] + loopy_insns + ["""
                             tgt_expansions[tgt_ibox - tgt_base_ibox, {i}] = \
                                 tgt_expansions[tgt_ibox - tgt_base_ibox, {i}] \
                                 + coeff{i} \
                                 {{id_prefix=write_expn,dep=compute_coeff*,
                                     nosync=read_coeff*}}
-                            """.format(i=i) for i in range(ncoeffs)] + ["""
+                            """.format(i=i) for i in range(ncoeffs_tgt)] + ["""
                         end
                     end
                 end
@@ -861,9 +927,9 @@ class E2EFromChildren(E2EBase):
                     lp.GlobalArg("box_child_ids", None,
                         shape="nchildren, aligned_nboxes"),
                     lp.GlobalArg("tgt_expansions", None,
-                        shape=("ntgt_level_boxes", ncoeffs), offset=lp.auto),
+                        shape=("ntgt_level_boxes", ncoeffs_tgt), offset=lp.auto),
                     lp.GlobalArg("src_expansions", None,
-                        shape=("nsrc_level_boxes", ncoeffs), offset=lp.auto),
+                        shape=("nsrc_level_boxes", ncoeffs_src), offset=lp.auto),
                     lp.ValueArg("src_base_ibox,tgt_base_ibox", np.int32),
                     lp.ValueArg("ntgt_level_boxes,nsrc_level_boxes", np.int32),
                     lp.ValueArg("aligned_nboxes", np.int32),
diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py
index abdd0410..493e7045 100644
--- a/sumpy/expansion/local.py
+++ b/sumpy/expansion/local.py
@@ -29,6 +29,8 @@ from sumpy.expansion import (
 from sumpy.tools import mi_increment_axis, matvec_toeplitz_upper_triangular
 from pytools import single_valued
 from typing import Tuple, Any
+import pymbolic
+import loopy as lp
 
 import logging
 logger = logging.getLogger(__name__)
@@ -50,6 +52,7 @@ class LocalExpansionBase(ExpansionBase):
     .. attribute:: kernel
     .. attribute:: order
     .. attribute:: use_rscale
+    .. attribute:: use_fft_for_m2l
     .. attribute:: use_preprocessing_for_m2l
 
     .. automethod:: m2l_translation_classes_dependent_data
@@ -60,19 +63,26 @@ class LocalExpansionBase(ExpansionBase):
     .. automethod:: m2l_postprocess_local_nexprs
     .. automethod:: translate_from
     """
-    init_arg_names = ("kernel", "order", "use_rscale", "use_preprocessing_for_m2l")
+    init_arg_names = ("kernel", "order", "use_rscale", "use_fft_for_m2l",
+            "use_preprocessing_for_m2l")
 
     def __init__(self, kernel, order, use_rscale=None,
-            use_preprocessing_for_m2l=False):
+            use_fft_for_m2l=False, use_preprocessing_for_m2l=None):
         super().__init__(kernel, order, use_rscale)
-        self.use_preprocessing_for_m2l = use_preprocessing_for_m2l
+        self.use_fft_for_m2l = use_fft_for_m2l
+        if use_preprocessing_for_m2l is None:
+            self.use_preprocessing_for_m2l = use_fft_for_m2l
+        else:
+            self.use_preprocessing_for_m2l = use_preprocessing_for_m2l
 
     def with_kernel(self, kernel):
         return type(self)(kernel, self.order, self.use_rscale,
+                use_fft_for_m2l=self.use_fft_for_m2l,
                 use_preprocessing_for_m2l=self.use_preprocessing_for_m2l)
 
     def update_persistent_hash(self, key_hash, key_builder):
         super().update_persistent_hash(key_hash, key_builder)
+        key_builder.rec(key_hash, self.use_fft_for_m2l)
         key_builder.rec(key_hash, self.use_preprocessing_for_m2l)
 
     def __eq__(self, other):
@@ -81,6 +91,7 @@ class LocalExpansionBase(ExpansionBase):
             and self.kernel == other.kernel
             and self.order == other.order
             and self.use_rscale == other.use_rscale
+            and self.use_fft_for_m2l == other.use_fft_for_m2l
             and self.use_preprocessing_for_m2l == other.use_preprocessing_for_m2l
         )
 
@@ -308,13 +319,10 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
     def m2l_translation_classes_dependent_ndata(self, src_expansion):
         """Returns number of expressions in M2L global precomputation step.
         """
-        mis_with_dummy_rows, mis_without_dummy_rows, _ = \
+        mis_with_dummy_rows, _, _ = \
             self._m2l_translation_classes_dependent_data_mis(src_expansion)
 
-        if self.use_preprocessing_for_m2l:
-            return len(mis_with_dummy_rows)
-        else:
-            return len(mis_without_dummy_rows)
+        return len(mis_with_dummy_rows)
 
     def _m2l_translation_classes_dependent_data_mis(self, src_expansion):
         """We would like to compute the M2L by way of a circulant matrix below.
@@ -431,12 +439,12 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
             vector[i] = add_to_sac(sac,
                         vector_full[srcplusderiv_ident_to_index[term]])
 
-        if self.use_preprocessing_for_m2l:
-            # Add zero values needed to make the translation matrix circulant
-            derivatives_full = [0]*len(circulant_matrix_mis)
-            for expr, mi in zip(vector, needed_vector_terms):
-                derivatives_full[circulant_matrix_ident_to_index[mi]] = expr
+        # Add zero values needed to make the translation matrix circulant
+        derivatives_full = [0]*len(circulant_matrix_mis)
+        for expr, mi in zip(vector, needed_vector_terms):
+            derivatives_full[circulant_matrix_ident_to_index[mi]] = expr
 
+        if self.use_fft_for_m2l:
             # Note that the matrix we have now is a mirror image of a
             # circulant matrix. We reverse the first column to get the
             # first column for the circulant matrix and then finally
@@ -444,7 +452,7 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
             # matrix.
             return fft(list(reversed(derivatives_full)), sac=sac)
 
-        return vector
+        return derivatives_full
 
     def m2l_preprocess_multipole_exprs(self, src_expansion, src_coeff_exprs, sac,
             src_rscale):
@@ -461,15 +469,16 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
             input_vector[circulant_matrix_ident_to_index[term]] = \
                     add_to_sac(sac, coeff)
 
-        if self.use_preprocessing_for_m2l:
+        if self.use_fft_for_m2l:
             return fft(input_vector, sac=sac)
         else:
-            # When FFT is turned off, there is no preprocessing needed
-            # Therefore no copying is done and the multipole expansion is sent to
-            # the main M2L routine as it is. This method is used internally in the
-            # the main M2l routine to avoid code duplication.
             return input_vector
 
+    def m2l_preprocess_multipole_nexprs(self, src_expansion):
+        circulant_matrix_mis, _, _ = \
+            self._m2l_translation_classes_dependent_data_mis(src_expansion)
+        return len(circulant_matrix_mis)
+
     def m2l_postprocess_local_exprs(self, src_expansion, m2l_result, src_rscale,
             tgt_rscale, sac):
         circulant_matrix_mis, needed_vector_terms, max_mi = \
@@ -477,7 +486,7 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
         circulant_matrix_ident_to_index = dict((ident, i) for i, ident in
                             enumerate(circulant_matrix_mis))
 
-        if self.use_preprocessing_for_m2l:
+        if self.use_fft_for_m2l:
             n = len(circulant_matrix_mis)
             m2l_result = fft(m2l_result, inverse=True, sac=sac)
             # since we reversed the M2L matrix, we reverse the result
@@ -493,6 +502,9 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
 
         return result
 
+    def m2l_postprocess_local_nexprs(self, src_expansion):
+        return self.m2l_translation_classes_dependent_ndata(src_expansion)
+
     def translate_from(self, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, _fast_version=True,
             m2l_translation_classes_dependent_data=None):
@@ -514,8 +526,6 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
         if isinstance(src_expansion, VolumeTaylorMultipoleExpansionBase):
             circulant_matrix_mis, needed_vector_terms, max_mi = \
                 self._m2l_translation_classes_dependent_data_mis(src_expansion)
-            circulant_matrix_ident_to_index = {ident: i for i, ident in
-                                enumerate(circulant_matrix_mis)}
 
             if not m2l_translation_classes_dependent_data:
                 derivatives = self.m2l_translation_classes_dependent_data(
@@ -523,26 +533,23 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
             else:
                 derivatives = m2l_translation_classes_dependent_data
 
-            if self.use_preprocessing_for_m2l:
-                assert m2l_translation_classes_dependent_data is not None
-                assert len(src_coeff_exprs) == len(
-                        m2l_translation_classes_dependent_data)
-                result = [a*b for a, b in zip(m2l_translation_classes_dependent_data,
-                    src_coeff_exprs)]
+            if self.use_fft_for_m2l:
+                assert len(src_coeff_exprs) == len(derivatives)
+                result = [a*b for a, b in zip(derivatives, src_coeff_exprs)]
             else:
-                derivatives_full = [0]*len(circulant_matrix_mis)
-                for expr, mi in zip(derivatives, needed_vector_terms):
-                    derivatives_full[circulant_matrix_ident_to_index[mi]] = expr
-
-                input_vector = self.m2l_preprocess_multipole_exprs(src_expansion,
-                    src_coeff_exprs, sac, src_rscale)
+                if not self.use_preprocessing_for_m2l:
+                    src_coeff_exprs = self.m2l_preprocess_multipole_exprs(
+                        src_expansion, src_coeff_exprs, sac, src_rscale)
 
-                # Do the matvec
-                output = matvec_toeplitz_upper_triangular(input_vector,
-                    derivatives_full)
+                # Returns a big symbolic sum of matrix entries
+                # (FIXME? Though this is just the correctness-checking
+                # fallback for the FFT anyhow)
+                result = matvec_toeplitz_upper_triangular(src_coeff_exprs,
+                    derivatives)
 
-                result = self.m2l_postprocess_local_exprs(src_expansion, output,
-                    src_rscale, tgt_rscale, sac)
+                if not self.use_preprocessing_for_m2l:
+                    result = self.m2l_postprocess_local_exprs(src_expansion,
+                        result, src_rscale, tgt_rscale, sac)
 
             logger.info("building translation operator: done")
             return result
@@ -682,15 +689,61 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
         logger.info("building translation operator: done")
         return result
 
+    def loopy_translate_from(self, src_expansion):
+        from sumpy.expansion.multipole import VolumeTaylorMultipoleExpansionBase
+
+        if isinstance(src_expansion, VolumeTaylorMultipoleExpansionBase):
+            if self.use_preprocessing_for_m2l:
+                ncoeff_src = self.m2l_preprocess_multipole_nexprs(src_expansion)
+                ncoeff_tgt = self.m2l_postprocess_local_nexprs(src_expansion)
+                icoeff_src = pymbolic.var("icoeff_src")
+                icoeff_tgt = pymbolic.var("icoeff_tgt")
+                domains = [f"{{[icoeff_tgt]: 0<=icoeff_tgt<{ncoeff_tgt} }}"]
+
+                coeff = pymbolic.var("coeff")
+                src_coeffs = pymbolic.var("src_coeffs")
+                m2l_translation_classes_dependent_data = pymbolic.var("data")
+
+                if self.use_fft_for_m2l:
+                    expr = src_coeffs[icoeff_tgt] \
+                            * m2l_translation_classes_dependent_data[icoeff_tgt]
+                else:
+                    toeplitz_first_row = src_coeffs[icoeff_src-icoeff_tgt]
+                    vector = m2l_translation_classes_dependent_data[icoeff_src]
+                    expr = toeplitz_first_row * vector
+                    domains.append(
+                        f"{{[icoeff_src]: icoeff_tgt<=icoeff_src<{ncoeff_src} }}")
+
+                expr = src_coeffs[icoeff_tgt] \
+                    * m2l_translation_classes_dependent_data[icoeff_tgt]
+
+                insns = [
+                    lp.Assignment(
+                        assignee=coeff[icoeff_tgt],
+                        expression=coeff[icoeff_tgt] + expr),
+                ]
+                return lp.make_function(domains, insns,
+                        kernel_data=[
+                            lp.GlobalArg("coeff, src_coeffs, data",
+                                shape=lp.auto),
+                            lp.ValueArg("src_rscale, tgt_rscale"),
+                            ...],
+                        name="e2e",
+                        lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+                        )
+        raise NotImplementedError(
+            f"A direct loopy kernel for translation from "
+            f"{src_expansion} to {self} is not implemented.")
+
 
 class VolumeTaylorLocalExpansion(
         VolumeTaylorExpansion,
         VolumeTaylorLocalExpansionBase):
 
     def __init__(self, kernel, order, use_rscale=None,
-            use_preprocessing_for_m2l=False):
+            use_fft_for_m2l=False, use_preprocessing_for_m2l=None):
         VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale,
-                use_preprocessing_for_m2l)
+                use_fft_for_m2l, use_preprocessing_for_m2l=None)
         VolumeTaylorExpansion.__init__(self, kernel, order, use_rscale)
 
 
@@ -699,9 +752,9 @@ class LinearPDEConformingVolumeTaylorLocalExpansion(
         VolumeTaylorLocalExpansionBase):
 
     def __init__(self, kernel, order, use_rscale=None,
-            use_preprocessing_for_m2l=False):
+            use_fft_for_m2l=False, use_preprocessing_for_m2l=None):
         VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale,
-                use_preprocessing_for_m2l)
+                use_fft_for_m2l, use_preprocessing_for_m2l)
         LinearPDEConformingVolumeTaylorExpansion.__init__(
                 self, kernel, order, use_rscale)
 
@@ -745,9 +798,9 @@ class BiharmonicConformingVolumeTaylorLocalExpansion(
 
 class _FourierBesselLocalExpansion(LocalExpansionBase):
     def __init__(self, kernel, order, use_rscale=None,
-            use_preprocessing_for_m2l=False):
+            use_fft_for_m2l=False, use_preprocessing_for_m2l=None):
 
-        if use_preprocessing_for_m2l:
+        if use_fft_for_m2l:
             # FIXME: expansion with FFT is correct symbolically and can be verified
             # with sympy. However there are numerical issues that we have to deal
             # with. Greengard and Rokhlin 1988 attributes this to numerical
@@ -758,6 +811,7 @@ class _FourierBesselLocalExpansion(LocalExpansionBase):
                              "supported yet.")
 
         super().__init__(kernel, order, use_rscale,
+                use_fft_for_m2l=use_fft_for_m2l,
                 use_preprocessing_for_m2l=use_preprocessing_for_m2l)
 
     def get_storage_index(self, k):
@@ -829,7 +883,7 @@ class _FourierBesselLocalExpansion(LocalExpansionBase):
                     Hankel1(m + j, arg_scale * dvec_len, 0)
                     * sym.exp(sym.I * (m + j) * new_center_angle_rel_old_center))
 
-        if self.use_preprocessing_for_m2l:
+        if self.use_fft_for_m2l:
             order = src_expansion.order
             # For this expansion, we have a mirror image of a Toeplitz matrix.
             # First, we have to take the mirror image of the M2L matrix.
@@ -860,18 +914,20 @@ class _FourierBesselLocalExpansion(LocalExpansionBase):
         for m in src_expansion.get_coefficient_identifiers():
             src_coeff_exprs[src_expansion.get_storage_index(m)] *= src_rscale**abs(m)
 
-        if self.use_preprocessing_for_m2l:
+        if self.use_fft_for_m2l:
             src_coeff_exprs = list(reversed(src_coeff_exprs))
             src_coeff_exprs += [0] * (len(src_coeff_exprs) - 1)
-            res = fft(src_coeff_exprs, sac=sac)
-            return res
+            return fft(src_coeff_exprs, sac=sac)
         else:
             return src_coeff_exprs
 
+    def m2l_preprocess_multipole_nexprs(self, src_expansion):
+        return 2*src_expansion.order + 1
+
     def m2l_postprocess_local_exprs(self, src_expansion, m2l_result, src_rscale,
             tgt_rscale, sac):
 
-        if self.use_preprocessing_for_m2l:
+        if self.use_fft_for_m2l:
             m2l_result = fft(m2l_result, inverse=True, sac=sac)
             m2l_result = m2l_result[:2*self.order+1]
 
@@ -883,6 +939,9 @@ class _FourierBesselLocalExpansion(LocalExpansionBase):
 
         return result
 
+    def m2l_postprocess_local_nexprs(self, src_expansion):
+        return 2*self.order + 1
+
     def translate_from(self, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, m2l_translation_classes_dependent_data=None):
         from sumpy.symbolic import sym_real_norm_2, BesselJ
@@ -915,40 +974,85 @@ class _FourierBesselLocalExpansion(LocalExpansionBase):
             else:
                 derivatives = m2l_translation_classes_dependent_data
 
-            translated_coeffs = []
-            if self.use_preprocessing_for_m2l:
+            if self.use_fft_for_m2l:
                 assert m2l_translation_classes_dependent_data is not None
                 assert len(derivatives) == len(src_coeff_exprs)
-                for a, b in zip(derivatives, src_coeff_exprs):
-                    translated_coeffs.append(a * b)
-                return translated_coeffs
-
-            src_coeff_exprs = self.m2l_preprocess_multipole_exprs(src_expansion,
-                    src_coeff_exprs, sac, src_rscale)
+                translated_coeffs = [a * b for a, b in zip(derivatives,
+                    src_coeff_exprs)]
+            else:
+                if not self.use_preprocessing_for_m2l:
+                    src_coeff_exprs = self.m2l_preprocess_multipole_exprs(
+                        src_expansion, src_coeff_exprs, sac, src_rscale)
 
-            for j in self.get_coefficient_identifiers():
-                translated_coeffs.append(
+                translated_coeffs = [
                     sum(derivatives[m + j + self.order + src_expansion.order]
-                        * src_coeff_exprs[src_expansion.get_storage_index(m)]
-                        for m in src_expansion.get_coefficient_identifiers()))
+                            * src_coeff_exprs[src_expansion.get_storage_index(m)]
+                        for m in src_expansion.get_coefficient_identifiers())
+                    for j in self.get_coefficient_identifiers()]
+
+                if not self.use_preprocessing_for_m2l:
+                    translated_coeffs = self.m2l_postprocess_local_exprs(
+                        src_expansion, translated_coeffs, src_rscale, tgt_rscale,
+                        sac)
 
-            translated_coeffs = self.m2l_postprocess_local_exprs(src_expansion,
-                translated_coeffs, src_rscale, tgt_rscale, sac)
             return translated_coeffs
 
         raise RuntimeError("do not know how to translate %s to %s"
                            % (type(src_expansion).__name__,
                                type(self).__name__))
 
+    def loopy_translate_from(self, src_expansion):
+        if isinstance(src_expansion, self.mpole_expn_class):
+            if self.use_preprocessing_for_m2l:
+                ncoeff_src = self.m2l_preprocess_multipole_nexprs(src_expansion)
+                ncoeff_tgt = self.m2l_postprocess_local_nexprs(src_expansion)
+
+                icoeff_src = pymbolic.var("icoeff_src")
+                icoeff_tgt = pymbolic.var("icoeff_tgt")
+                domains = [f"{{[icoeff_tgt]: 0<=icoeff_tgt<{ncoeff_tgt} }}"]
+
+                coeff = pymbolic.var("coeff")
+                src_coeffs = pymbolic.var("src_coeffs")
+                m2l_translation_classes_dependent_data = pymbolic.var("data")
+
+                if self.use_fft_for_m2l:
+                    expr = src_coeffs[icoeff_tgt] \
+                            * m2l_translation_classes_dependent_data[icoeff_tgt]
+                else:
+                    expr = src_coeffs[icoeff_src] \
+                           * m2l_translation_classes_dependent_data[
+                                   icoeff_tgt + icoeff_src]
+                    domains.append(
+                            f"{{[icoeff_src]: 0<=icoeff_src<{ncoeff_src} }}")
+
+                insns = [
+                    lp.Assignment(
+                        assignee=coeff[icoeff_tgt],
+                        expression=coeff[icoeff_tgt] + expr),
+                ]
+                return lp.make_function(domains, insns,
+                        kernel_data=[
+                            lp.GlobalArg("coeff, src_coeffs, data",
+                                shape=lp.auto),
+                            lp.ValueArg("src_rscale, tgt_rscale"),
+                            ...],
+                        name="e2e",
+                        lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+                        )
+        raise NotImplementedError(
+            f"A direct loopy kernel for translation from "
+            f"{src_expansion} to {self} is not implemented.")
+
 
 class H2DLocalExpansion(_FourierBesselLocalExpansion):
     def __init__(self, kernel, order, use_rscale=None,
-            use_preprocessing_for_m2l=False):
+            use_fft_for_m2l=False, use_preprocessing_for_m2l=None):
         from sumpy.kernel import HelmholtzKernel
         assert (isinstance(kernel.get_base_kernel(), HelmholtzKernel)
                 and kernel.dim == 2)
 
         super().__init__(kernel, order, use_rscale,
+                use_fft_for_m2l=use_fft_for_m2l,
                 use_preprocessing_for_m2l=use_preprocessing_for_m2l)
 
         from sumpy.expansion.multipole import H2DMultipoleExpansion
@@ -960,12 +1064,13 @@ class H2DLocalExpansion(_FourierBesselLocalExpansion):
 
 class Y2DLocalExpansion(_FourierBesselLocalExpansion):
     def __init__(self, kernel, order, use_rscale=None,
-            use_preprocessing_for_m2l=False):
+            use_fft_for_m2l=False, use_preprocessing_for_m2l=None):
         from sumpy.kernel import YukawaKernel
         assert (isinstance(kernel.get_base_kernel(), YukawaKernel)
                 and kernel.dim == 2)
 
         super().__init__(kernel, order, use_rscale,
+                use_fft_for_m2l=use_fft_for_m2l,
                 use_preprocessing_for_m2l=use_preprocessing_for_m2l)
 
         from sumpy.expansion.multipole import Y2DMultipoleExpansion
diff --git a/sumpy/fmm.py b/sumpy/fmm.py
index bfeadbe5..7403a1db 100644
--- a/sumpy/fmm.py
+++ b/sumpy/fmm.py
@@ -40,7 +40,7 @@ from sumpy import (
         E2EFromCSR, M2LUsingTranslationClassesDependentData,
         E2EFromChildren, E2EFromParent,
         M2LGenerateTranslationClassesDependentData,
-        M2LPreprocessMultipole)
+        M2LPreprocessMultipole, M2LPostprocessLocal)
 from sumpy.tools import to_complex_dtype
 
 
@@ -67,7 +67,7 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
             local_expansion_factory,
             target_kernels, exclude_self=False, use_rscale=None,
             strength_usage=None, source_kernels=None,
-            use_preprocessing_for_m2l=False):
+            use_fft_for_m2l=False, use_preprocessing_for_m2l=None):
         """
         :arg multipole_expansion_factory: a callable of a single argument (order)
             that returns a multipole expansion.
@@ -77,6 +77,10 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
         :arg exclude_self: whether the self contribution should be excluded
         :arg strength_usage: passed unchanged to p2l, p2m and p2p.
         :arg source_kernels: passed unchanged to p2l, p2m and p2p.
+        :arg use_fft_for_m2l: Use an FFT based multipole-to-local expansion.
+        :arg use_preprocessing_for_m2l: do preprocessing of the source multipole
+            expansion and postprocessing of the target local expansion for
+            multipole-to-local expansion.
         """
         self.multipole_expansion_factory = multipole_expansion_factory
         self.local_expansion_factory = local_expansion_factory
@@ -85,7 +89,11 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
         self.exclude_self = exclude_self
         self.use_rscale = use_rscale
         self.strength_usage = strength_usage
-        self.use_preprocessing_for_m2l = use_preprocessing_for_m2l
+        self.use_fft_for_m2l = use_fft_for_m2l
+        if use_preprocessing_for_m2l is None:
+            self.use_preprocessing_for_m2l = use_fft_for_m2l
+        else:
+            self.use_preprocessing_for_m2l = use_preprocessing_for_m2l
 
         super().__init__()
 
@@ -103,6 +111,7 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
     @memoize_method
     def local_expansion(self, order):
         return self.local_expansion_factory(order, self.use_rscale,
+                use_fft_for_m2l=self.use_fft_for_m2l,
                 use_preprocessing_for_m2l=self.use_preprocessing_for_m2l)
 
     @memoize_method
@@ -148,6 +157,12 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
                 self.multipole_expansion(src_order),
                 self.local_expansion(tgt_order))
 
+    @memoize_method
+    def m2l_postprocess_local_kernel(self, src_order, tgt_order):
+        return M2LPostprocessLocal(self.cl_context,
+                self.multipole_expansion(src_order),
+                self.local_expansion(tgt_order))
+
     @memoize_method
     def l2l(self, src_order, tgt_order):
         return E2EFromParent(self.cl_context,
@@ -310,7 +325,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
 
         self.dtype = dtype
 
-        if not self.tree_indep.use_preprocessing_for_m2l:
+        if not self.tree_indep.use_fft_for_m2l:
             # If not FFT, we don't need complex dtypes
             self.preprocessed_mpole_dtype = dtype
         elif preprocessed_mpole_dtype is not None:
@@ -345,7 +360,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
         if base_kernel.is_translation_invariant:
             if translation_classes_data is None:
                 from warnings import warn
-                if self.tree_indep.use_preprocessing_for_m2l:
+                if self.tree_indep.use_fft_for_m2l:
                     raise NotImplementedError(
                          "FFT based List 2 (multipole-to-local) translations "
                          "without translation_classes_data argument is not "
@@ -358,14 +373,14 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                          "to the wrangler for optimized List 2.",
                          SumpyTranslationClassesDataNotSuppliedWarning,
                          stacklevel=2)
-                self.supports_optimized_m2l = False
+                self.supports_translation_classes = False
             else:
-                self.supports_optimized_m2l = True
+                self.supports_translation_classes = True
         else:
-            self.supports_optimized_m2l = False
+            self.supports_translation_classes = False
 
         self.translation_classes_data = translation_classes_data
-        self.use_preprocessing_for_m2l = self.tree_indep.use_preprocessing_for_m2l
+        self.use_fft_for_m2l = self.tree_indep.use_fft_for_m2l
 
     # {{{ data vector utilities
 
@@ -465,7 +480,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
         def order_to_size(order):
             mpole_expn = self.tree_indep.multipole_expansion(order)
             local_expn = self.tree_indep.local_expansion(order)
-            res = local_expn.m2l_translation_classes_dependent_ndata(mpole_expn)
+            res = local_expn.m2l_preprocess_multipole_nexprs(mpole_expn)
             return res
 
         return build_csr_level_starts(self.level_orders, order_to_size,
@@ -485,6 +500,11 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
         return (box_start,
                 mpole_exps[expn_start:expn_stop].reshape(box_stop-box_start, -1))
 
+    m2l_work_array_view = m2l_preproc_mpole_expansions_view
+    m2l_work_array_zeros = m2l_preproc_mpole_expansion_zeros
+    m2l_work_array_level_starts = \
+            m2l_preproc_mpole_expansions_level_starts
+
     def output_zeros(self, template_ary):
         """Return a potentials array (which must support addition) capable of
         holding a potential value for each target in the tree. Note that
@@ -707,16 +727,15 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
 
             m2l_translation_classes_dependent_data = \
                     m2l_translation_classes_dependent_data.with_queue(None)
-
         return m2l_translation_classes_dependent_data
 
     def _add_m2l_precompute_kwargs(self, kwargs_for_m2l,
             lev):
-        """This method is used for addin the information needed for a
+        """This method is used for adding the information needed for a
         multipole-to-local translation with precomputation to the keywords
         passed to multipole-to-local translation.
         """
-        if not self.supports_optimized_m2l:
+        if not self.supports_translation_classes:
             return
         m2l_translation_classes_dependent_data = \
                 self.multipole_to_local_precompute()
@@ -736,18 +755,19 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
             target_boxes, src_box_starts, src_box_lists,
             mpole_exps):
 
-        precompute_evts = []
+        preprocess_evts = []
         queue = mpole_exps.queue
+        local_exps = self.local_expansion_zeros(mpole_exps)
 
-        if self.use_preprocessing_for_m2l:
+        if self.tree_indep.use_preprocessing_for_m2l:
             preprocessed_mpole_exps = \
-                    self.m2l_preproc_mpole_expansion_zeros(mpole_exps)
+                self.m2l_preproc_mpole_expansion_zeros(mpole_exps)
             for lev in range(self.tree.nlevels):
                 order = self.level_orders[lev]
                 preprocess_mpole_kernel = \
                     self.tree_indep.m2l_preprocess_mpole_kernel(order, order)
 
-                source_level_start_ibox, source_mpoles_view = \
+                _, source_mpoles_view = \
                         self.multipole_expansions_view(mpole_exps, lev)
 
                 _, preprocessed_source_mpoles_view = \
@@ -766,14 +786,17 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                     src_rscale=level_to_rscale(self.tree, lev),
                     **self.kernel_extra_kwargs
                 )
-                precompute_evts.append(evt)
+                preprocess_evts.append(evt)
             mpole_exps = preprocessed_mpole_exps
+            m2l_work_array = self.m2l_work_array_zeros(local_exps)
             mpole_exps_view_func = self.m2l_preproc_mpole_expansions_view
+            local_exps_view_func = self.m2l_work_array_view
         else:
+            m2l_work_array = local_exps
             mpole_exps_view_func = self.multipole_expansions_view
+            local_exps_view_func = self.local_expansions_view
 
-        events = []
-        local_exps = self.local_expansion_zeros(mpole_exps)
+        translate_evts = []
 
         for lev in range(self.tree.nlevels):
             start, stop = level_start_target_box_nrs[lev:lev+2]
@@ -781,17 +804,18 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                 continue
 
             order = self.level_orders[lev]
-            m2l = self.tree_indep.m2l(order, order, self.supports_optimized_m2l)
+            m2l = self.tree_indep.m2l(order, order,
+                    self.supports_translation_classes)
 
             source_level_start_ibox, source_mpoles_view = \
                     mpole_exps_view_func(mpole_exps, lev)
-            target_level_start_ibox, target_local_exps_view = \
-                    self.local_expansions_view(local_exps, lev)
+            target_level_start_ibox, target_locals_view = \
+                    local_exps_view_func(m2l_work_array, lev)
 
             kwargs = dict(
                     src_expansions=source_mpoles_view,
                     src_base_ibox=source_level_start_ibox,
-                    tgt_expansions=target_local_exps_view,
+                    tgt_expansions=target_locals_view,
                     tgt_base_ibox=target_level_start_ibox,
 
                     target_boxes=target_boxes[start:stop],
@@ -809,11 +833,45 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                     kwargs["m2l_translation_classes_dependent_data"].size == 0:
                 # There is nothing to do for this level
                 continue
-            evt, _ = m2l(queue, **kwargs, wait_for=precompute_evts)
+            evt, _ = m2l(queue, **kwargs, wait_for=preprocess_evts)
 
-            events.append(evt)
+            translate_evts.append(evt)
 
-        return (local_exps, SumpyTimingFuture(queue, events))
+        postprocess_evts = []
+
+        if self.tree_indep.use_preprocessing_for_m2l:
+            for lev in range(self.tree.nlevels):
+                order = self.level_orders[lev]
+                postprocess_local_kernel = \
+                    self.tree_indep.m2l_postprocess_local_kernel(order, order)
+
+                _, target_locals_view = \
+                        self.local_expansions_view(local_exps, lev)
+
+                _, target_locals_before_postprocessing_view = \
+                        self.m2l_work_array_view(
+                                m2l_work_array, lev)
+
+                tr_classes = self.m2l_translation_class_level_start_box_nrs()
+                if tr_classes[lev] == tr_classes[lev + 1]:
+                    # There is no M2L happening in this level
+                    continue
+
+                evt, _ = postprocess_local_kernel(
+                    queue,
+                    tgt_expansions=target_locals_view,
+                    tgt_expansions_before_postprocessing=(
+                        target_locals_before_postprocessing_view),
+                    src_rscale=level_to_rscale(self.tree, lev),
+                    tgt_rscale=level_to_rscale(self.tree, lev),
+                    wait_for=translate_evts,
+                    **self.kernel_extra_kwargs,
+                )
+                postprocess_evts.append(evt)
+
+        timing_events = preprocess_evts + translate_evts + postprocess_evts
+
+        return (local_exps, SumpyTimingFuture(queue, timing_events))
 
     def eval_multipoles(self,
             target_boxes_by_source_level, source_boxes_by_level, mpole_exps):
diff --git a/test/test_fmm.py b/test/test_fmm.py
index 3e2f84fd..d1fd9690 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -190,7 +190,7 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
                 ctx,
                 partial(mpole_expn_class, knl),
                 partial(local_expn_class, knl),
-                target_kernels, use_preprocessing_for_m2l=use_fft)
+                target_kernels, use_fft_for_m2l=use_fft)
 
         with warnings.catch_warnings():
             if not optimized_m2l:
-- 
GitLab