From 9849af5ab27cd57ee6a8a94664d25f6abba7bd6b Mon Sep 17 00:00:00 2001
From: Isuru Fernando <idf2@illinois.edu>
Date: Mon, 1 Aug 2022 15:52:46 -0500
Subject: [PATCH] FFT using pyvkfft and use loopy callables (#114)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Use a separate class for M2L translation

* Fix docs and caching

* Fix p2p warning

* Use VkFFT for M2L generate data

* Fix profiling events

* simplify m2l data zeros

* Add pyvkfft to requirements

* Fix flake8 warning

* Fix typo

* VkFFT for M2L preprocess local

* vkfft for postprocess local

* Fix AggregateProfilingEvent

* Fix another typo

* M2L Translation Factory

* vim markers

* Fix tests

* Fix toys

* Fix test_m2l_toeplitz

* Fix more tests

* Use a better rscale to get the test passing

* Use pytential dev branch

* remove whitespace on blank line

* Try 2r/order instead of r/order

* fix using updated pytential

* Fix tests

* use pytential branch with pyvkfft req

* Add explanation about caller being responsible for the FFT

* Fix for bessel

* Add pyvkfft to setup.py reqs

* use list comprehension

* Type annotations

* fix vim marker

* remove unused function

* m2l_data_inner -> m2l_data

* more descriptive name for child_knl

* knl -> expr_knl for clarity

* move loop unroll to optimized

* Add explanation about translation_classes_dependent_data_loopy_knl

* make coeffs output only and rewrite

* Re-arrange m2l so that event processing is easier

* flake8: single quotes -> double quotes

* Fix data not being input

* make args to cached_vkfft_app explicit

* cache vkfftapp in wrangler

* keep coeffs is_input and is_output for e2e

* out-of-place fft

* Use a separate queue for configuration

* allocate array for out-of-place

* fix typo

* Remove caching of opencl fft app

* Comment out pytentual fork

* fix vkfft queues

* use private API for now

* Add comment on pyvkfft PR

* remove inplace

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 .github/workflows/ci.yml |   6 +-
 .test-conda-env-py3.yml  |   1 +
 requirements.txt         |   1 +
 setup.py                 |   1 +
 sumpy/e2e.py             | 237 +++++++++++------------------
 sumpy/expansion/m2l.py   | 311 +++++++++++++++++++++++++++++++++------
 sumpy/fmm.py             | 177 ++++++++++++++--------
 sumpy/p2p.py             |   2 +-
 sumpy/tools.py           |  94 +++++++++++-
 test/test_fmm.py         |  15 +-
 10 files changed, 588 insertions(+), 257 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c369d1fc..a66e81b7 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -88,9 +88,9 @@ jobs:
             run: |
                 curl -L -O https://tiker.net/ci-support-v0
                 . ./ci-support-v0
-                if [[ "$DOWNSTREAM_PROJECT" == "pytential" && "$GITHUB_HEAD_REF" == "m2l" ]]; then
-                   DOWNSTREAM_PROJECT=https://github.com/isuruf/pytential.git@m2l_translation
-                fi
+                # if [[ "$DOWNSTREAM_PROJECT" == "pytential" && "$GITHUB_HEAD_REF" == "fft" ]]; then
+                #    DOWNSTREAM_PROJECT=https://github.com/isuruf/pytential.git@pyvkfft
+                # fi
                 test_downstream "$DOWNSTREAM_PROJECT"
 
 # vim: sw=4
diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml
index 916e1a81..575e02e2 100644
--- a/.test-conda-env-py3.yml
+++ b/.test-conda-env-py3.yml
@@ -15,3 +15,4 @@ dependencies:
 - python-symengine
 - pyfmmlib
 - pyrsistent
+- pyvkfft
diff --git a/requirements.txt b/requirements.txt
index ed07527b..a26c33e3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 numpy
 sympy
 pyrsistent
+pyvkfft
 git+https://github.com/inducer/pytools.git#egg=pytools
 git+https://github.com/inducer/pymbolic.git#egg=pymbolic
 git+https://github.com/inducer/islpy.git#egg=islpy
diff --git a/setup.py b/setup.py
index 5ca54417..f9f0be9b 100644
--- a/setup.py
+++ b/setup.py
@@ -106,5 +106,6 @@ setup(
         "dataclasses>=0.7;python_version<='3.6'",
         "sympy>=0.7.2",
         "pymbolic>=2021.1",
+        "pyvkfft>=2022.1",
     ],
 )
diff --git a/sumpy/e2e.py b/sumpy/e2e.py
index b35f86cd..5bd0ab2b 100644
--- a/sumpy/e2e.py
+++ b/sumpy/e2e.py
@@ -361,16 +361,16 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
 
         domains = []
         insns = self.get_translation_loopy_insns(result_dtype)
-        coeff = pymbolic.var("coeff")
+        tgt_coeffs = pymbolic.var("tgt_coeffs")
         for i in range(ncoeff_tgt):
             expr = pymbolic.var(f"tgt_coeff{i}")
-            insn = lp.Assignment(assignee=coeff[i],
-                    expression=coeff[i] + expr)
+            insn = lp.Assignment(assignee=tgt_coeffs[i],
+                    expression=tgt_coeffs[i] + expr)
             insns.append(insn)
 
         return lp.make_function(domains, insns,
                         kernel_data=[
-                            lp.GlobalArg("coeff", shape=(ncoeff_tgt,),
+                            lp.GlobalArg("tgt_coeffs", shape=(ncoeff_tgt,),
                                 is_output=True, is_input=True),
                             lp.GlobalArg("src_coeffs", shape=(ncoeff_src,)),
                             lp.GlobalArg("data", shape=(ndata,)),
@@ -420,7 +420,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                     <> isrc_start = src_box_starts[itgt_box]
                     <> isrc_stop = src_box_starts[itgt_box+1]
                     for icoeff_tgt
-                        <> coeffs[icoeff_tgt] = 0 {id=init_coeffs, dup=icoeff_tgt}
+                        <> tgt_expansion[icoeff_tgt] = 0 \
+                            {id=init_coeffs, dup=icoeff_tgt}
                     end
                     for isrc_box
                         <> src_ibox = src_box_lists[isrc_box] \
@@ -430,8 +431,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                         <> translation_class_rel = \
                                 translation_class - translation_classes_level_start \
                                 {id=translation_offset}
-                        [icoeff_tgt]: coeffs[icoeff_tgt] = e2e(
-                            [icoeff_tgt]: coeffs[icoeff_tgt],
+                        [icoeff_tgt]: tgt_expansion[icoeff_tgt] = e2e(
+                            [icoeff_tgt]: tgt_expansion[icoeff_tgt],
                             [icoeff_src]: src_expansions[src_ibox - src_base_ibox,
                                 icoeff_src],
                             [idep]: m2l_translation_classes_dependent_data[
@@ -441,7 +442,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                             )  {dep=init_coeffs,id=update_coeffs}
                     end
                     tgt_expansions[tgt_ibox - tgt_base_ibox, icoeff_tgt] = \
-                            coeffs[icoeff_tgt] {dep=update_coeffs, dup=icoeff_tgt}
+                            tgt_expansion[icoeff_tgt] \
+                            {dep=update_coeffs, dup=icoeff_tgt}
                 end
                 """],
                 [
@@ -533,55 +535,34 @@ class M2LGenerateTranslationClassesDependentData(E2EBase):
     """
     default_name = "m2l_generate_translation_classes_dependent_data"
 
-    def get_translation_loopy_insns(self, result_dtype):
-        from sumpy.symbolic import make_sym_vector
-        dvec = make_sym_vector("d", self.dim)
-
-        src_rscale = sym.Symbol("src_rscale")
-
-        m2l_translation = self.tgt_expansion.m2l_translation
-
-        from sumpy.assignment_collection import SymbolicAssignmentCollection
-        sac = SymbolicAssignmentCollection()
-        tgt_coeff_names = [
-            sac.assign_unique(
-                f"m2l_translation_classes_dependent_expr{i}", coeff_i)
-            for i, coeff_i in enumerate(
-                m2l_translation.translation_classes_dependent_data(
-                    self.tgt_expansion, self.src_expansion, src_rscale,
-                    dvec, sac=sac))]
-
-        sac.run_global_cse()
-
-        from sumpy.codegen import to_loopy_insns
-        return to_loopy_insns(
-                sac.assignments.items(),
-                vector_names={"d"},
-                pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()],
-                retain_names=tgt_coeff_names,
-                complex_dtype=to_complex_dtype(result_dtype),
-                )
-
     def get_kernel(self, result_dtype):
+        m2l_translation = self.tgt_expansion.m2l_translation
         m2l_translation_classes_dependent_ndata = \
-            self.tgt_expansion.m2l_translation.translation_classes_dependent_ndata(
+            m2l_translation.translation_classes_dependent_ndata(
                 self.tgt_expansion, self.src_expansion)
+
+        translation_classes_data_knl = \
+            m2l_translation.translation_classes_dependent_data_loopy_knl(
+                self.tgt_expansion, self.src_expansion, result_dtype)
+
         from sumpy.tools import gather_loopy_arguments
         loopy_knl = lp.make_kernel(
                 [
                     "{[itr_class]: 0<=itr_class<ntranslation_classes}",
                     "{[idim]: 0<=idim<dim}",
+                    "{[idata]: 0<=idata<m2l_translation_classes_dependent_ndata}",
                     ],
                 ["""
                 for itr_class
                     <> d[idim] = m2l_translation_vectors[idim, \
-                            itr_class + translation_classes_level_start]
-
-                    """] + self.get_translation_loopy_insns(result_dtype) + ["""
-                    m2l_translation_classes_dependent_data[itr_class, {idx}] = \
-                            m2l_translation_classes_dependent_expr{idx}
-                    """.format(idx=i) for i in range(
-                        m2l_translation_classes_dependent_ndata)] + ["""
+                            itr_class + translation_classes_level_start] \
+                            {id=set_d,dup=idim}
+                    [idata]: m2l_translation_classes_dependent_data[
+                            itr_class, idata] = \
+                        m2l_data(
+                            src_rscale,
+                            [idim]: d[idim],
+                        ) {id=update,dep=set_d}
                 end
                 """],
                 [
@@ -600,14 +581,18 @@ class M2LGenerateTranslationClassesDependentData(E2EBase):
                 name=self.name,
                 assumptions="ntranslation_classes>=1",
                 default_offset=lp.auto,
-                fixed_parameters=dict(dim=self.dim),
+                fixed_parameters=dict(
+                    dim=self.dim,
+                    m2l_translation_classes_dependent_ndata=(
+                        m2l_translation_classes_dependent_ndata)),
                 lang_version=MOST_RECENT_LANGUAGE_VERSION
                 )
 
-        for knl in [self.src_expansion.kernel, self.tgt_expansion.kernel]:
-            loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
+        for expr_knl in [self.src_expansion.kernel, self.tgt_expansion.kernel]:
+            loopy_knl = expr_knl.prepare_loopy_kernel(loopy_knl)
 
-        loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
+        loopy_knl = lp.merge([loopy_knl, translation_classes_data_knl])
+        loopy_knl = lp.inline_callable_kernel(loopy_knl, "m2l_data")
         loopy_knl = lp.set_options(loopy_knl,
                 enforce_variable_access_ordered="no_check")
 
@@ -616,6 +601,7 @@ class M2LGenerateTranslationClassesDependentData(E2EBase):
     def get_optimized_kernel(self, result_dtype):
         # FIXME
         knl = self.get_kernel(result_dtype)
+        knl = lp.tag_inames(knl, "idim*:unr")
         knl = lp.tag_inames(knl, {"itr_class": "g.0"})
 
         return knl
@@ -656,54 +642,29 @@ class M2LPreprocessMultipole(E2EBase):
 
     default_name = "m2l_preprocess_multipole"
 
-    def get_loopy_insns(self, result_dtype):
-        src_coeff_exprs = [
-            sym.Symbol(f"src_coeff{i}")
-            for i in range(len(self.src_expansion))]
-
-        src_rscale = sym.Symbol("src_rscale")
-
-        from sumpy.assignment_collection import SymbolicAssignmentCollection
-        sac = SymbolicAssignmentCollection()
-
-        preprocessed_src_coeff_names = [
-                sac.assign_unique(f"preprocessed_src_coeff{i}", coeff_i)
-                for i, coeff_i in enumerate(
-                    self.tgt_expansion.m2l_translation.preprocess_multipole_exprs(
-                        self.tgt_expansion, self.src_expansion, src_coeff_exprs,
-                        sac=sac, src_rscale=src_rscale))]
-
-        sac.run_global_cse()
-
-        from sumpy.codegen import to_loopy_insns
-        return to_loopy_insns(
-                sac.assignments.items(),
-                vector_names={"d"},
-                pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()],
-                retain_names=preprocessed_src_coeff_names,
-                complex_dtype=to_complex_dtype(result_dtype),
-                )
-
     def get_kernel(self, result_dtype):
+        m2l_translation = self.tgt_expansion.m2l_translation
         nsrc_coeffs = len(self.src_expansion)
         npreprocessed_src_coeffs = \
-            self.tgt_expansion.m2l_translation.preprocess_multipole_nexprs(
-                self.tgt_expansion, self.src_expansion)
+            m2l_translation.preprocess_multipole_nexprs(self.tgt_expansion,
+                self.src_expansion)
+        single_box_preprocess_knl = m2l_translation.preprocess_multipole_loopy_knl(
+            self.tgt_expansion, self.src_expansion, result_dtype)
+
         from sumpy.tools import gather_loopy_arguments
         loopy_knl = lp.make_kernel(
                 [
                     "{[isrc_box]: 0<=isrc_box<nsrc_boxes}",
-                    ],
+                    "{[isrc_coeff]: 0<=isrc_coeff<nsrc_coeffs}",
+                    "{[itgt_coeff]: 0<=itgt_coeff<npreprocessed_src_coeffs}",
+                ],
                 ["""
                 for isrc_box
-                """] + ["""
-                    <> src_coeff{idx} = src_expansions[isrc_box, {idx}]
-                """.format(idx=i) for i in range(nsrc_coeffs)] + [
-                ] + self.get_loopy_insns(result_dtype) + ["""
-                    preprocessed_src_expansions[isrc_box, {idx}] = \
-                        preprocessed_src_coeff{idx}
-                    """.format(idx=i) for i in range(
-                        npreprocessed_src_coeffs)] + ["""
+                    [itgt_coeff]: preprocessed_src_expansions[isrc_box, itgt_coeff] \
+                        = m2l_preprocess_inner(
+                            src_rscale,
+                            [isrc_coeff]: src_expansions[isrc_box, isrc_coeff],
+                        )
                 end
                 """],
                 [
@@ -718,22 +679,27 @@ class M2LPreprocessMultipole(E2EBase):
                 ] + gather_loopy_arguments([self.src_expansion, self.tgt_expansion]),
                 name=self.name,
                 assumptions="nsrc_boxes>=1",
+                fixed_parameters=dict(
+                    nsrc_coeffs=nsrc_coeffs,
+                    npreprocessed_src_coeffs=npreprocessed_src_coeffs),
                 default_offset=lp.auto,
-                fixed_parameters=dict(dim=self.dim),
-                lang_version=MOST_RECENT_LANGUAGE_VERSION
+                lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
                 )
 
         for expn in [self.src_expansion.kernel, self.tgt_expansion.kernel]:
             loopy_knl = expn.prepare_loopy_kernel(loopy_knl)
 
-        loopy_knl = lp.set_options(loopy_knl,
-                enforce_variable_access_ordered="no_check")
+        loopy_knl = lp.merge([loopy_knl, single_box_preprocess_knl])
+        loopy_knl = lp.inline_callable_kernel(loopy_knl, "m2l_preprocess_inner")
+
         return loopy_knl
 
     def get_optimized_kernel(self, result_dtype):
         # FIXME
         knl = self.get_kernel(result_dtype)
-        knl = lp.split_iname(knl, "isrc_box", 16, outer_tag="g.0")
+        knl = lp.split_iname(knl, "isrc_box", 64, outer_tag="g.0",
+                             within=f"in_kernel:{self.name}")
+        knl = lp.add_inames_for_unused_hw_axes(knl)
         return knl
 
     def __call__(self, queue, **kwargs):
@@ -744,8 +710,10 @@ class M2LPreprocessMultipole(E2EBase):
         preprocessed_src_expansions = kwargs.pop("preprocessed_src_expansions")
         result_dtype = preprocessed_src_expansions.dtype
         knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
+
         return knl(queue,
-            preprocessed_src_expansions=preprocessed_src_expansions, **kwargs)
+                preprocessed_src_expansions=preprocessed_src_expansions, **kwargs)
+
 # }}}
 
 
@@ -756,68 +724,32 @@ class M2LPostprocessLocal(E2EBase):
 
     default_name = "m2l_postprocess_local"
 
-    def get_loopy_insns(self, result_dtype):
+    def get_kernel(self, result_dtype):
         m2l_translation = self.tgt_expansion.m2l_translation
-        ncoeffs_before_postprocessing = \
+        ntgt_coeffs = len(self.tgt_expansion)
+        ntgt_coeffs_before_postprocessing = \
             m2l_translation.postprocess_local_nexprs(self.tgt_expansion,
-                                                     self.src_expansion)
-
-        tgt_coeff_exprs_before_postprocessing = [
-            sym.Symbol(f"tgt_coeff_before_postprocessing{i}")
-            for i in range(ncoeffs_before_postprocessing)]
-
-        src_rscale = sym.Symbol("src_rscale")
-        tgt_rscale = sym.Symbol("tgt_rscale")
-
-        from sumpy.assignment_collection import SymbolicAssignmentCollection
-        sac = SymbolicAssignmentCollection()
-
-        tgt_coeff_exprs = m2l_translation.postprocess_local_exprs(
-            self.tgt_expansion, self.src_expansion,
-            tgt_coeff_exprs_before_postprocessing,
-            sac=sac, src_rscale=src_rscale, tgt_rscale=tgt_rscale)
-
-        if result_dtype in (np.float32, np.float64):
-            real_func = sym.Function("real")
-            tgt_coeff_exprs = [real_func(expr) for expr in
-                    tgt_coeff_exprs]
-
-        tgt_coeff_names = [
-            sac.assign_unique(f"tgt_coeff{i}", coeff_i)
-            for i, coeff_i in enumerate(tgt_coeff_exprs)]
-
-        sac.run_global_cse()
+                self.src_expansion)
 
-        from sumpy.codegen import to_loopy_insns
-        return to_loopy_insns(
-                sac.assignments.items(),
-                vector_names={"d"},
-                pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()],
-                retain_names=tgt_coeff_names,
-                complex_dtype=to_complex_dtype(result_dtype),
-                )
+        single_box_postprocess_knl = m2l_translation.postprocess_local_loopy_knl(
+            self.tgt_expansion, self.src_expansion, result_dtype)
 
-    def get_kernel(self, result_dtype):
-        ntgt_coeffs = len(self.tgt_expansion)
-        ntgt_coeffs_before_postprocessing = \
-            self.tgt_expansion.m2l_translation.postprocess_local_nexprs(
-                self.tgt_expansion, self.src_expansion)
         from sumpy.tools import gather_loopy_arguments
         loopy_knl = lp.make_kernel(
                 [
                     "{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
-                    ],
+                    "{[isrc_coeff]: 0<=isrc_coeff<nsrc_coeffs}",
+                    "{[itgt_coeff]: 0<=itgt_coeff<ntgt_coeffs}",
+                ],
                 ["""
                 for itgt_box
-                """] + ["""
-                    <> tgt_coeff_before_postprocessing{idx} = \
-                            tgt_expansions_before_postprocessing[itgt_box, {idx}]
-                """.format(idx=i) for i in range(
-                    ntgt_coeffs_before_postprocessing)]
-                + self.get_loopy_insns(result_dtype) + ["""
-                    tgt_expansions[itgt_box, {idx}] = \
-                        tgt_coeff{idx}
-                    """.format(idx=i) for i in range(ntgt_coeffs)] + ["""
+                    [itgt_coeff]: tgt_expansions[itgt_box, itgt_coeff] = \
+                        m2l_postprocess_inner(
+                            tgt_rscale,
+                            src_rscale,
+                            [isrc_coeff]: tgt_expansions_before_postprocessing[ \
+                            itgt_box, isrc_coeff],
+                       )
                 end
                 """],
                 [
@@ -834,13 +766,20 @@ class M2LPostprocessLocal(E2EBase):
                 name=self.name,
                 assumptions="ntgt_boxes>=1",
                 default_offset=lp.auto,
-                fixed_parameters=dict(dim=self.dim),
+                fixed_parameters=dict(
+                    dim=self.dim,
+                    nsrc_coeffs=ntgt_coeffs_before_postprocessing,
+                    ntgt_coeffs=ntgt_coeffs,
+                ),
                 lang_version=MOST_RECENT_LANGUAGE_VERSION
                 )
 
         for expn in [self.src_expansion.kernel, self.tgt_expansion.kernel]:
             loopy_knl = expn.prepare_loopy_kernel(loopy_knl)
 
+        loopy_knl = lp.merge([loopy_knl, single_box_postprocess_knl])
+        loopy_knl = lp.inline_callable_kernel(loopy_knl, "m2l_postprocess_inner")
+
         loopy_knl = lp.set_options(loopy_knl,
                 enforce_variable_access_ordered="no_check")
         return loopy_knl
@@ -859,8 +798,8 @@ class M2LPostprocessLocal(E2EBase):
         tgt_expansions = kwargs.pop("tgt_expansions")
         result_dtype = tgt_expansions.dtype
         knl = self.get_cached_optimized_kernel(result_dtype=result_dtype)
-        return knl(queue,
-            tgt_expansions=tgt_expansions, **kwargs)
+
+        return knl(queue, tgt_expansions=tgt_expansions, **kwargs)
 
 # }}}
 
diff --git a/sumpy/expansion/m2l.py b/sumpy/expansion/m2l.py
index 3120595e..796e951c 100644
--- a/sumpy/expansion/m2l.py
+++ b/sumpy/expansion/m2l.py
@@ -24,10 +24,10 @@ from typing import Tuple, Any
 
 import pymbolic
 import loopy as lp
+import numpy as np
 import sumpy.symbolic as sym
 from sumpy.tools import (
-        add_to_sac, fft,
-        matvec_toeplitz_upper_triangular)
+        add_to_sac, matvec_toeplitz_upper_triangular)
 
 import logging
 logger = logging.getLogger(__name__)
@@ -164,6 +164,9 @@ class M2LTranslationBase:
         distance between per level, these can be precomputed for the tree.
         In :mod:`boxtree`, these distances are referred to as translation
         classes.
+
+        When FFT is turned on, the output expressions are assumed to be
+        transformed into Fourier space at the end by the caller.
         """
         return tuple()
 
@@ -178,17 +181,27 @@ class M2LTranslationBase:
         """
         return 0
 
+    def translation_classes_dependent_data_loopy_knl(self, tgt_expansion,
+            src_expansion, result_dtype):
+        """Return a :mod:`loopy` kernel that calculates the data described by
+        :func:`~sumpy.expansion.m2l.M2LTranslationBase.translation_classes_dependent_data`.
+        :arg result_dtype: The :mod:`numpy` type of the result.
+        """
+        return translation_classes_dependent_data_loopy_knl(tgt_expansion,
+            src_expansion, result_dtype)
+
     def preprocess_multipole_exprs(self, tgt_expansion, src_expansion,
             src_coeff_exprs, sac, src_rscale):
         """Return the preprocessed multipole expansion for an optimized M2L.
         Preprocessing happens once per source box before M2L translation is done.
 
-        When FFT is turned on, the input expressions are transformed into Fourier
-        space. These expressions are used in a separate :mod:`loopy` kernel
-        to avoid having to transform for each target and source box pair.
-        When FFT is turned off, the expressions are equal to the multipole
-        expansion coefficients with zeros added
-        to make the M2L computation a circulant matvec.
+        These expressions are used in a separate :mod:`loopy` kernel
+        to avoid having to process for each target and source box pair.
+        When FFT is turned on, the output expressions are assumed to be
+        transformed into Fourier space at the end by the caller.
+        When FFT is turned off, the output expressions are equal to the multipole
+        expansion coefficients with zeros added to make the M2L computation a
+        circulant matvec.
         """
         raise NotImplementedError
 
@@ -211,8 +224,8 @@ class M2LTranslationBase:
         is done and before storing the expansion coefficients for the local
         expansion.
 
-        When FFT is turned on, the output expressions are transformed from Fourier
-        space back to the original space.
+        When FFT is turned on, the output expressions are assumed to have been
+        transformed from Fourier space back to the original space by the caller.
         """
         raise NotImplementedError
 
@@ -397,9 +410,9 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase):
 
     def preprocess_multipole_exprs(self, tgt_expansion, src_expansion,
             src_coeff_exprs, sac, src_rscale):
-        circulant_matrix_mis, needed_vector_terms, max_mi = \
-                self._translation_classes_dependent_data_mis(tgt_expansion,
-                                                                 src_expansion)
+        circulant_matrix_mis, _, _ = \
+            self._translation_classes_dependent_data_mis(tgt_expansion,
+                src_expansion)
         circulant_matrix_ident_to_index = {ident: i for i, ident in
                             enumerate(circulant_matrix_mis)}
 
@@ -416,12 +429,65 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase):
     def preprocess_multipole_nexprs(self, tgt_expansion, src_expansion):
         circulant_matrix_mis, _, _ = \
             self._translation_classes_dependent_data_mis(tgt_expansion,
-                                                             src_expansion)
+                src_expansion)
         return len(circulant_matrix_mis)
 
+    def preprocess_multipole_loopy_knl(self, tgt_expansion, src_expansion,
+            result_dtype):
+
+        circulant_matrix_mis, _, _ = \
+            self._translation_classes_dependent_data_mis(tgt_expansion,
+                src_expansion)
+        circulant_matrix_ident_to_index = dict((ident, i) for i, ident in
+            enumerate(circulant_matrix_mis))
+
+        ncoeff_src = len(src_expansion.get_coefficient_identifiers())
+        ncoeff_preprocessed = self.preprocess_multipole_nexprs(tgt_expansion,
+            src_expansion)
+
+        output_coeffs = pymbolic.var("output_coeffs")
+        input_coeffs = pymbolic.var("input_coeffs")
+        ioutput_coeff = pymbolic.var("ioutput_coeff")
+
+        domains = [
+            "{[ioutput_coeff]: 0<=ioutput_coeff<noutput_coeffs}",
+        ]
+        insns = [
+            lp.Assignment(
+                assignee=output_coeffs[ioutput_coeff],
+                expression=0,
+                id="init",
+            )
+        ]
+        prev_insn = "init"
+        for icoeff_src, term in enumerate(
+                src_expansion.get_coefficient_identifiers()):
+            new_icoeff_src = circulant_matrix_ident_to_index[term]
+            insns += [
+                lp.Assignment(
+                    assignee=output_coeffs[new_icoeff_src],
+                    expression=input_coeffs[icoeff_src],
+                    id=f"coeff_insn_{icoeff_src}",
+                    depends_on=frozenset([prev_insn])
+                ),
+            ]
+            prev_insn = f"coeff_insn_{icoeff_src}"
+
+        return lp.make_function(domains, insns,
+            kernel_data=[
+                lp.ValueArg("src_rscale", None),
+                lp.GlobalArg("output_coeffs", None, shape=ncoeff_preprocessed,
+                    is_input=False, is_output=True),
+                lp.GlobalArg("input_coeffs", None, shape=ncoeff_src),
+                ...],
+            name="m2l_preprocess_inner",
+            lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+            fixed_parameters=dict(noutput_coeffs=ncoeff_preprocessed),
+        )
+
     def postprocess_local_exprs(self, tgt_expansion, src_expansion, m2l_result,
             src_rscale, tgt_rscale, sac):
-        circulant_matrix_mis, needed_vector_terms, max_mi = \
+        circulant_matrix_mis, _, _ = \
                 self._translation_classes_dependent_data_mis(tgt_expansion,
                                                                  src_expansion)
         circulant_matrix_ident_to_index = {ident: i for i, ident in
@@ -440,6 +506,99 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase):
         return self.translation_classes_dependent_ndata(
             tgt_expansion, src_expansion)
 
+    def postprocess_local_loopy_knl(self, tgt_expansion, src_expansion,
+            result_dtype):
+        circulant_matrix_mis, needed_vector_terms, _ = \
+            self._translation_classes_dependent_data_mis(tgt_expansion,
+                src_expansion)
+        circulant_matrix_ident_to_index = dict((ident, i) for i, ident in
+                            enumerate(circulant_matrix_mis))
+
+        ncoeff_tgt = len(tgt_expansion.get_coefficient_identifiers())
+        ncoeff_before_postprocessed = self.postprocess_local_nexprs(tgt_expansion,
+                src_expansion)
+        order = tgt_expansion.order
+
+        fixed_parameters = {
+            "ncoeff_tgt": ncoeff_tgt,
+            "ncoeff_before_postprocessed": ncoeff_before_postprocessed,
+            "order": order,
+        }
+
+        domains = [
+            "{[iorder]: 0<iorder<=order}"
+        ]
+
+        insns = ["<> rscale_ratio = tgt_rscale / src_rscale {id=rscale_ratio}"]
+
+        rscale_arr = pymbolic.var("rscale_arr")
+        rscale_ratio = pymbolic.var("rscale_ratio")
+        iorder = pymbolic.var("iorder")
+
+        insns += [
+            lp.Assignment(
+                assignee=rscale_arr[0],
+                expression=1,
+                id="rscale_arr0",
+                depends_on="rscale_ratio",
+            ),
+            lp.Assignment(
+                assignee=rscale_arr[iorder],
+                expression=rscale_arr[iorder - 1]*rscale_ratio,
+                id="rscale_arr",
+                depends_on="rscale_arr0",
+            ),
+        ]
+
+        output_coeffs = pymbolic.var("output_coeffs")
+        input_coeffs = pymbolic.var("input_coeffs")
+
+        if self.use_fft and result_dtype in \
+                (np.float64, np.float32):
+            result_func = pymbolic.var("real")
+        else:
+            def result_func(x):
+                return x
+
+        for new_icoeff_tgt, term in enumerate(
+                tgt_expansion.get_coefficient_identifiers()):
+            if self.use_fft:
+                # since we reversed the M2L matrix, we reverse the result
+                # to get the correct result
+                n = len(circulant_matrix_mis)
+                icoeff_tgt = n - 1 - circulant_matrix_ident_to_index[term]
+            else:
+                icoeff_tgt = circulant_matrix_ident_to_index[term]
+
+            insns += [
+                lp.Assignment(
+                    assignee=output_coeffs[new_icoeff_tgt],
+                    expression=result_func(
+                        input_coeffs[icoeff_tgt]) * rscale_arr[sum(term)],
+                    id=f"coeff_insn_{new_icoeff_tgt}",
+                    depends_on="rscale_arr",
+                )
+            ]
+
+        return lp.make_function(domains, insns,
+            kernel_data=[
+                lp.ValueArg("src_rscale", None),
+                lp.ValueArg("tgt_rscale", None),
+                lp.GlobalArg("output_coeffs", None,
+                    shape=ncoeff_tgt, is_input=False,
+                    is_output=True),
+                lp.GlobalArg("input_coeffs", None,
+                    shape=ncoeff_before_postprocessed,
+                    is_output=False, is_input=True),
+                lp.TemporaryVariable("rscale_arr",
+                    None,
+                    shape=(order + 1,)),
+                ...],
+            name="m2l_postprocess_inner",
+            lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+            fixed_parameters=fixed_parameters,
+        )
+
 
 # }}} VolumeTaylorM2LTranslation
 
@@ -468,7 +627,7 @@ class VolumeTaylorM2LWithPreprocessedMultipoles(VolumeTaylorM2LTranslation):
         icoeff_tgt = pymbolic.var("icoeff_tgt")
         domains = [f"{{[icoeff_tgt]: 0<=icoeff_tgt<{ncoeff_tgt} }}"]
 
-        coeff = pymbolic.var("coeff")
+        tgt_coeffs = pymbolic.var("tgt_coeffs")
         src_coeffs = pymbolic.var("src_coeffs")
         translation_classes_dependent_data = pymbolic.var("data")
 
@@ -487,14 +646,17 @@ class VolumeTaylorM2LWithPreprocessedMultipoles(VolumeTaylorM2LTranslation):
 
         insns = [
             lp.Assignment(
-                assignee=coeff[icoeff_tgt],
-                expression=coeff[icoeff_tgt] + expr),
+                assignee=tgt_coeffs[icoeff_tgt],
+                expression=tgt_coeffs[icoeff_tgt] + expr
+            ),
         ]
         return lp.make_function(domains, insns,
                 kernel_data=[
-                    lp.GlobalArg("coeff, src_coeffs, data",
-                        shape=lp.auto),
-                    lp.ValueArg("src_rscale, tgt_rscale"),
+                    lp.GlobalArg("tgt_coeffs", shape=lp.auto, is_input=True,
+                        is_output=True),
+                    lp.GlobalArg("src_coeffs, data",
+                        shape=lp.auto, is_input=True, is_output=False),
+                    lp.ValueArg("src_rscale, tgt_rscale", is_input=True),
                     ...],
                 name="e2e",
                 lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
@@ -519,7 +681,13 @@ class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles):
 
     def translation_classes_dependent_data(self, tgt_expansion, src_expansion,
             src_rscale, dvec, sac):
+        """Return an iterable of expressions that needs to be precomputed
+        for multipole-to-local translations that depend only on the
+        distance between the multipole center and the local center which
+        is given as *dvec*.
 
+        The final result should be transformed using an FFT.
+        """
         derivatives_full = super().translation_classes_dependent_data(
             tgt_expansion, src_expansion, src_rscale, dvec, sac)
         # Note that the matrix we have now is a mirror image of a
@@ -527,14 +695,7 @@ class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles):
         # first column for the circulant matrix and then finally
         # use the FFT for convolution represented by the circulant
         # matrix.
-        return fft(list(reversed(derivatives_full)), sac=sac)
-
-    def preprocess_multipole_exprs(self, tgt_expansion, src_expansion,
-            src_coeff_exprs, sac, src_rscale):
-        input_vector = super().preprocess_multipole_exprs(
-            tgt_expansion, src_expansion, src_coeff_exprs, sac, src_rscale)
-
-        return fft(input_vector, sac=sac)
+        return list(reversed(derivatives_full))
 
     def postprocess_local_exprs(self, tgt_expansion, src_expansion, m2l_result,
             src_rscale, tgt_rscale, sac):
@@ -542,7 +703,6 @@ class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles):
                 self._translation_classes_dependent_data_mis(tgt_expansion,
                                                                  src_expansion)
         n = len(circulant_matrix_mis)
-        m2l_result = fft(m2l_result, inverse=True, sac=sac)
         # since we reversed the M2L matrix, we reverse the result
         # to get the correct result
         m2l_result = list(reversed(m2l_result[:n]))
@@ -665,7 +825,7 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation):
         icoeff_tgt = pymbolic.var("icoeff_tgt")
         domains = [f"{{[icoeff_tgt]: 0<=icoeff_tgt<{ncoeff_tgt} }}"]
 
-        coeff = pymbolic.var("coeff")
+        tgt_coeffs = pymbolic.var("tgt_coeffs")
         src_coeffs = pymbolic.var("src_coeffs")
         translation_classes_dependent_data = pymbolic.var("data")
 
@@ -681,14 +841,16 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation):
 
         insns = [
             lp.Assignment(
-                assignee=coeff[icoeff_tgt],
-                expression=coeff[icoeff_tgt] + expr),
+                assignee=tgt_coeffs[icoeff_tgt],
+                expression=tgt_coeffs[icoeff_tgt] + expr),
         ]
         return lp.make_function(domains, insns,
                 kernel_data=[
-                    lp.GlobalArg("coeff, src_coeffs, data",
-                        shape=lp.auto),
-                    lp.ValueArg("src_rscale, tgt_rscale"),
+                    lp.GlobalArg("tgt_coeffs", shape=lp.auto, is_input=True,
+                        is_output=True),
+                    lp.GlobalArg("src_coeffs, data",
+                        shape=lp.auto, is_input=True, is_output=False),
+                    lp.ValueArg("src_rscale, tgt_rscale", is_input=True),
                     ...],
                 name="e2e",
                 lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
@@ -726,7 +888,6 @@ class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles):
     def translation_classes_dependent_data(self, tgt_expansion, src_expansion,
             src_rscale, dvec, sac):
 
-        from sumpy.tools import fft
         translation_classes_dependent_data = \
             super().translation_classes_dependent_data(tgt_expansion,
                 src_expansion, src_rscale, dvec, sac)
@@ -748,26 +909,94 @@ class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles):
 
         first_column_circulant = list(first_column_toeplitz) + \
                 list(reversed(first_row_toeplitz))
-        return fft(first_column_circulant, sac)
+        return first_column_circulant
 
     def preprocess_multipole_exprs(self, tgt_expansion, src_expansion,
             src_coeff_exprs, sac, src_rscale):
 
-        from sumpy.tools import fft
         result = super().preprocess_multipole_exprs(tgt_expansion,
             src_expansion, src_coeff_exprs, sac, src_rscale)
 
         result = list(reversed(result))
         result += [0] * (len(result) - 1)
-        return fft(result, sac=sac)
+        return result
 
     def postprocess_local_exprs(self, tgt_expansion, src_expansion,
             m2l_result, src_rscale, tgt_rscale, sac):
 
-        m2l_result = fft(m2l_result, inverse=True, sac=sac)
         m2l_result = m2l_result[:2*tgt_expansion.order+1]
         return super().postprocess_local_exprs(tgt_expansion,
             src_expansion, m2l_result, src_rscale, tgt_rscale, sac)
 
+
 # }}} FourierBesselM2LWithFFT
+
+# {{{ translation_classes_dependent_data_loopy_knl
+
+def translation_classes_dependent_data_loopy_knl(tgt_expansion, src_expansion,
+            result_dtype):
+    """
+    This is a helper function to create a loopy kernel to generate translation
+    classes dependent data. This function uses symbolic expressions given by the
+    M2L translation, converts them to pymbolic expressions and generates a loopy
+    kernel. Note that the loopy kernel returned has lots of expressions in it and
+    takes a long time. Therefore, this function should be used only as a fallback
+    when there is no "loop-y" kernel to calculate the data.
+    """
+    src_rscale = sym.Symbol("src_rscale")
+    dvec = sym.make_sym_vector("d", tgt_expansion.dim)
+    from sumpy.assignment_collection import SymbolicAssignmentCollection
+    sac = SymbolicAssignmentCollection()
+    derivatives = tgt_expansion.m2l_translation.translation_classes_dependent_data(
+        tgt_expansion, src_expansion, src_rscale, dvec, sac)
+
+    vec_name = "m2l_translation_classes_dependent_data"
+
+    tgt_coeff_names = [
+            sac.assign_unique("m2l_translation_classes_dependent_data%d" % i,
+                coeff_i)
+            for i, coeff_i in enumerate(derivatives)]
+    sac.run_global_cse()
+
+    from sumpy.codegen import to_loopy_insns
+    from sumpy.tools import to_complex_dtype
+    insns = to_loopy_insns(
+            sac.assignments.items(),
+            vector_names=set(["d"]),
+            pymbolic_expr_maps=[tgt_expansion.get_code_transformer()],
+            retain_names=tgt_coeff_names,
+            complex_dtype=to_complex_dtype(result_dtype),
+            )
+
+    data = pymbolic.var("m2l_translation_classes_dependent_data")
+    depends_on = None
+    for i in range(len(insns)):
+        insn = insns[i]
+        if isinstance(insn, lp.Assignment) and \
+                insn.assignee.name.startswith(vec_name):
+            idx = int(insn.assignee.name[len(vec_name):])
+            insns[i] = lp.Assignment(
+                assignee=data[idx],
+                expression=insn.expression,
+                id=f"data_{idx}",
+                depends_on=depends_on,
+            )
+            depends_on = frozenset([f"data_{idx}"])
+
+    knl = lp.make_function([], insns,
+        kernel_data=[
+            lp.ValueArg("src_rscale", None),
+            lp.GlobalArg("d", None, shape=tgt_expansion.dim),
+            lp.GlobalArg(data.name, None,
+                shape=len(derivatives), is_input=False,
+                is_output=True),
+        ],
+        name="m2l_data",
+        lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+    )
+
+    return knl
+
+# }}} translation_classes_dependent_data_loopy_knl
+
 # vim: fdm=marker
diff --git a/sumpy/fmm.py b/sumpy/fmm.py
index 42564444..95afd2cc 100644
--- a/sumpy/fmm.py
+++ b/sumpy/fmm.py
@@ -41,7 +41,10 @@ from sumpy import (
         E2EFromChildren, E2EFromParent,
         M2LGenerateTranslationClassesDependentData,
         M2LPreprocessMultipole, M2LPostprocessLocal)
-from sumpy.tools import to_complex_dtype
+from sumpy.tools import (to_complex_dtype, AggregateProfilingEvent,
+        run_opencl_fft, get_opencl_fft_app)
+
+from typing import TypeVar, List, Union
 
 
 # {{{ tree-independent data for wrangler
@@ -176,6 +179,11 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
                           exclude_self=self.exclude_self,
                           strength_usage=self.strength_usage)
 
+    @memoize_method
+    def opencl_fft_app(self, shape, dtype):
+        with cl.CommandQueue(self.cl_context) as queue:
+            return get_opencl_fft_app(queue, shape, dtype)
+
 # }}}
 
 
@@ -184,16 +192,28 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
 _SECONDS_PER_NANOSECOND = 1e-9
 
 
+"""
+EventLike objects have an attribute native_event that returns
+a cl.Event that indicates the end of the event.
+"""
+EventLike = TypeVar("CLEventLike")
+
+
 class UnableToCollectTimingData(UserWarning):
     pass
 
 
 class SumpyTimingFuture:
 
-    def __init__(self, queue, events):
+    def __init__(self, queue, events: List[Union[cl.Event, EventLike]]):
         self.queue = queue
         self.events = events
 
+    @property
+    def native_events(self) -> List[cl.Event]:
+        return [evt if isinstance(evt, cl.Event) else evt.native_event
+                for evt in self.events]
+
     @memoize_method
     def result(self):
         from boxtree.timing import TimingResult
@@ -208,7 +228,7 @@ class SumpyTimingFuture:
             return TimingResult(wall_elapsed=None)
 
         if self.events:
-            pyopencl.wait_for_events(self.events)
+            pyopencl.wait_for_events(self.native_events)
 
         result = 0
         for event in self.events:
@@ -222,7 +242,7 @@ class SumpyTimingFuture:
         return all(
                 event.get_info(cl.event_info.COMMAND_EXECUTION_STATUS)
                 == cl.command_execution_status.COMPLETE
-                for event in self.events)
+                for event in self.native_events)
 
 # }}}
 
@@ -395,10 +415,18 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                 dtype=self.dtype)
 
     def m2l_translation_classes_dependent_data_zeros(self, queue):
-        return cl.array.zeros(
-                queue,
-                self.m2l_translation_classes_dependent_data_level_starts()[-1],
-                dtype=self.preprocessed_mpole_dtype)
+        result = []
+        for level in range(self.tree.nlevels):
+            expn_start, expn_stop = \
+                self.m2l_translation_classes_dependent_data_level_starts()[
+                    level:level+2]
+            translation_class_start, translation_class_stop = \
+                self.m2l_translation_class_level_start_box_nrs()[level:level+2]
+            exprs_level = cl.array.zeros(queue, expn_stop - expn_start,
+                                 dtype=self.preprocessed_mpole_dtype)
+            result.append(exprs_level.reshape(
+                            translation_class_stop - translation_class_start, -1))
+        return result
 
     def multipole_expansions_view(self, mpole_exps, level):
         expn_start, expn_stop = \
@@ -418,14 +446,10 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
 
     def m2l_translation_classes_dependent_data_view(self,
                 m2l_translation_classes_dependent_data, level):
-        expn_start, expn_stop = \
-            self.m2l_translation_classes_dependent_data_level_starts()[level:level+2]
-        translation_class_start, translation_class_stop = \
+        translation_class_start, _ = \
             self.m2l_translation_class_level_start_box_nrs()[level:level+2]
-
-        exprs_level = m2l_translation_classes_dependent_data[expn_start:expn_stop]
-        return (translation_class_start, exprs_level.reshape(
-                            translation_class_stop - translation_class_start, -1))
+        exprs_level = m2l_translation_classes_dependent_data[level]
+        return (translation_class_start, exprs_level)
 
     @memoize_method
     def m2l_preproc_mpole_expansions_level_starts(self):
@@ -440,18 +464,19 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                 level_starts=self.tree.level_start_box_nrs)
 
     def m2l_preproc_mpole_expansion_zeros(self, template_ary):
-        return cl.array.zeros(
-                template_ary.queue,
-                self.m2l_preproc_mpole_expansions_level_starts()[-1],
-                dtype=self.preprocessed_mpole_dtype)
-
-    def m2l_preproc_mpole_expansions_view(self, mpole_exps, level):
-        expn_start, expn_stop = \
+        result = []
+        for level in range(self.tree.nlevels):
+            expn_start, expn_stop = \
                 self.m2l_preproc_mpole_expansions_level_starts()[level:level+2]
-        box_start, box_stop = self.tree.level_start_box_nrs[level:level+2]
+            box_start, box_stop = self.tree.level_start_box_nrs[level:level+2]
+            exprs_level = cl.array.zeros(template_ary.queue, expn_stop - expn_start,
+                                 dtype=self.preprocessed_mpole_dtype)
+            result.append(exprs_level.reshape(box_stop - box_start, -1))
+        return result
 
-        return (box_start,
-                mpole_exps[expn_start:expn_stop].reshape(box_stop-box_start, -1))
+    def m2l_preproc_mpole_expansions_view(self, mpole_exps, level):
+        box_start, _ = self.tree.level_start_box_nrs[level:level+2]
+        return (box_start, mpole_exps[level])
 
     m2l_work_array_view = m2l_preproc_mpole_expansions_view
     m2l_work_array_zeros = m2l_preproc_mpole_expansion_zeros
@@ -528,6 +553,10 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
 
     # }}}
 
+    def run_opencl_fft(self, queue, input_vec, inverse, wait_for):
+        app = self.tree_indep.opencl_fft_app(input_vec.shape, input_vec.dtype)
+        return run_opencl_fft(app, queue, input_vec, inverse, wait_for)
+
     def form_multipoles(self,
             level_start_source_box_nrs, source_boxes,
             src_weight_vecs):
@@ -653,6 +682,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
 
     @memoize_method
     def multipole_to_local_precompute(self):
+        result = []
         with cl.CommandQueue(self.tree_indep.cl_context) as queue:
             m2l_translation_classes_dependent_data = \
                     self.m2l_translation_classes_dependent_data_zeros(queue)
@@ -672,6 +702,8 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                         m2l_translation_classes_dependent_data_view.shape[0]
 
                 if ntranslation_classes == 0:
+                    result.append(pyopencl.array.empty_like(
+                        m2l_translation_classes_dependent_data_view))
                     continue
 
                 data = self.translation_classes_data
@@ -689,13 +721,19 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                     ntranslation_vectors=m2l_translation_vectors.shape[1],
                     **self.kernel_extra_kwargs
                 )
-                m2l_translation_classes_dependent_data.add_event(evt)
 
-            m2l_translation_classes_dependent_data.finish()
+                if self.tree_indep.m2l_translation.use_fft:
+                    _, m2l_translation_classes_dependent_data_view = \
+                        self.run_opencl_fft(queue,
+                            m2l_translation_classes_dependent_data_view,
+                            inverse=False, wait_for=[evt])
+                result.append(m2l_translation_classes_dependent_data_view)
 
-            m2l_translation_classes_dependent_data = \
-                    m2l_translation_classes_dependent_data.with_queue(None)
-        return m2l_translation_classes_dependent_data
+            for lev in range(self.tree.nlevels):
+                result[lev].finish()
+
+            result = [arr.with_queue(None) for arr in result]
+        return result
 
     def _add_m2l_precompute_kwargs(self, kwargs_for_m2l,
             lev):
@@ -723,14 +761,33 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
             target_boxes, src_box_starts, src_box_lists,
             mpole_exps):
 
-        preprocess_evts = []
         queue = mpole_exps.queue
         local_exps = self.local_expansion_zeros(mpole_exps)
 
         if self.tree_indep.m2l_translation.use_preprocessing:
             preprocessed_mpole_exps = \
                 self.m2l_preproc_mpole_expansion_zeros(mpole_exps)
-            for lev in range(self.tree.nlevels):
+            m2l_work_array = self.m2l_work_array_zeros(local_exps)
+            mpole_exps_view_func = self.m2l_preproc_mpole_expansions_view
+            local_exps_view_func = self.m2l_work_array_view
+        else:
+            preprocessed_mpole_exps = mpole_exps
+            m2l_work_array = local_exps
+            mpole_exps_view_func = self.multipole_expansions_view
+            local_exps_view_func = self.local_expansions_view
+
+        preprocess_evts = []
+        translate_evts = []
+        postprocess_evts = []
+
+        for lev in range(self.tree.nlevels):
+            wait_for = []
+
+            start, stop = level_start_target_box_nrs[lev:lev+2]
+            if start == stop:
+                continue
+
+            if self.tree_indep.m2l_translation.use_preprocessing:
                 order = self.level_orders[lev]
                 preprocess_mpole_kernel = \
                     self.tree_indep.m2l_preprocess_mpole_kernel(order, order)
@@ -738,10 +795,6 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                 _, source_mpoles_view = \
                         self.multipole_expansions_view(mpole_exps, lev)
 
-                _, preprocessed_source_mpoles_view = \
-                        self.m2l_preproc_mpole_expansions_view(
-                                preprocessed_mpole_exps, lev)
-
                 tr_classes = self.m2l_translation_class_level_start_box_nrs()
                 if tr_classes[lev] == tr_classes[lev + 1]:
                     # There is no M2L happening in this level
@@ -750,33 +803,29 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                 evt, _ = preprocess_mpole_kernel(
                     queue,
                     src_expansions=source_mpoles_view,
-                    preprocessed_src_expansions=preprocessed_source_mpoles_view,
+                    preprocessed_src_expansions=preprocessed_mpole_exps[lev],
                     src_rscale=self.level_to_rscale(lev),
+                    wait_for=wait_for,
                     **self.kernel_extra_kwargs
                 )
-                preprocess_evts.append(evt)
-            mpole_exps = preprocessed_mpole_exps
-            m2l_work_array = self.m2l_work_array_zeros(local_exps)
-            mpole_exps_view_func = self.m2l_preproc_mpole_expansions_view
-            local_exps_view_func = self.m2l_work_array_view
-        else:
-            m2l_work_array = local_exps
-            mpole_exps_view_func = self.multipole_expansions_view
-            local_exps_view_func = self.local_expansions_view
+                wait_for.append(evt)
 
-        translate_evts = []
+                if self.tree_indep.m2l_translation.use_fft:
+                    evt_fft, preprocessed_mpole_exps[lev] = \
+                        self.run_opencl_fft(queue,
+                            preprocessed_mpole_exps[lev],
+                            inverse=False, wait_for=wait_for)
+                    wait_for.append(evt_fft.native_event)
+                    evt = AggregateProfilingEvent([evt, evt_fft])
 
-        for lev in range(self.tree.nlevels):
-            start, stop = level_start_target_box_nrs[lev:lev+2]
-            if start == stop:
-                continue
+                preprocess_evts.append(evt)
 
             order = self.level_orders[lev]
             m2l = self.tree_indep.m2l(order, order,
                     self.supports_translation_classes)
 
             source_level_start_ibox, source_mpoles_view = \
-                    mpole_exps_view_func(mpole_exps, lev)
+                    mpole_exps_view_func(preprocessed_mpole_exps, lev)
             target_level_start_ibox, target_locals_view = \
                     local_exps_view_func(m2l_work_array, lev)
 
@@ -801,14 +850,11 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                     kwargs["m2l_translation_classes_dependent_data"].size == 0:
                 # There is nothing to do for this level
                 continue
-            evt, _ = m2l(queue, **kwargs, wait_for=preprocess_evts)
-
+            evt, _ = m2l(queue, **kwargs, wait_for=wait_for)
+            wait_for.append(evt)
             translate_evts.append(evt)
 
-        postprocess_evts = []
-
-        if self.tree_indep.m2l_translation.use_preprocessing:
-            for lev in range(self.tree.nlevels):
+            if self.tree_indep.m2l_translation.use_preprocessing:
                 order = self.level_orders[lev]
                 postprocess_local_kernel = \
                     self.tree_indep.m2l_postprocess_local_kernel(order, order)
@@ -825,6 +871,13 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                     # There is no M2L happening in this level
                     continue
 
+                if self.tree_indep.m2l_translation.use_fft:
+                    evt_fft, target_locals_before_postprocessing_view = \
+                        self.run_opencl_fft(queue,
+                            target_locals_before_postprocessing_view,
+                            inverse=True, wait_for=wait_for)
+                    wait_for.append(evt_fft.native_event)
+
                 evt, _ = postprocess_local_kernel(
                     queue,
                     tgt_expansions=target_locals_view,
@@ -832,10 +885,14 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                         target_locals_before_postprocessing_view),
                     src_rscale=self.level_to_rscale(lev),
                     tgt_rscale=self.level_to_rscale(lev),
-                    wait_for=translate_evts,
+                    wait_for=wait_for,
                     **self.kernel_extra_kwargs,
                 )
-                postprocess_evts.append(evt)
+
+                if self.tree_indep.m2l_translation.use_fft:
+                    postprocess_evts.append(AggregateProfilingEvent([evt, evt_fft]))
+                else:
+                    postprocess_evts.append(evt)
 
         timing_events = preprocess_evts + translate_evts + postprocess_evts
 
diff --git a/sumpy/p2p.py b/sumpy/p2p.py
index f70ecedf..7944e556 100644
--- a/sumpy/p2p.py
+++ b/sumpy/p2p.py
@@ -468,7 +468,6 @@ class P2PFromCSR(P2PBase):
             "{[iknl]: 0 <= iknl < noutputs}",
             "{[isrc_box]: isrc_box_start <= isrc_box < isrc_box_end}",
             "{[idim]: 0 <= idim < dim}",
-            "{[istrength]: 0 <= istrength < nstrengths}",
             "{[isrc]: isrc_start <= isrc < isrc_end}"
         ]
 
@@ -483,6 +482,7 @@ class P2PFromCSR(P2PBase):
                     shape=(self.strength_count, max_nsources_in_one_box)),
             ]
             domains += [
+                "{[istrength]: 0 <= istrength < nstrengths}",
                 "{[inner]: 0 <= inner < nsplit}",
                 "{[itgt_offset_outer]: 0 <= itgt_offset_outer <= tgt_outer_limit}",
                 "{[isrc_offset_outer]: 0 <= isrc_offset_outer <= src_outer_limit}",
diff --git a/sumpy/tools.py b/sumpy/tools.py
index 14152d70..4c6c5891 100644
--- a/sumpy/tools.py
+++ b/sumpy/tools.py
@@ -39,11 +39,13 @@ __doc__ = """
 from pytools import memoize_method
 from pytools.tag import Tag, tag_dataclass
 import numbers
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 from pymbolic.mapper import WalkMapper
 
 import numpy as np
 import sumpy.symbolic as sym
+import pyopencl as cl
+import pyopencl.array as cla
 
 import loopy as lp
 from typing import Dict, Tuple, Any
@@ -931,6 +933,96 @@ def to_complex_dtype(dtype):
     except KeyError:
         raise RuntimeError(f"Unknown dtype: {dtype}")
 
+
+ProfileGetter = namedtuple("ProfileGetter", "start, end")
+
+
+class AggregateProfilingEvent:
+    """An object to hold a list of events and provides compatibility
+    with some of the functionality of :class:`pyopencl.Event`.
+    Assumes that the last event waits on all of the previous events.
+    """
+    def __init__(self, events):
+        self.events = events[:]
+        if isinstance(events[-1], cl.Event):
+            self.native_event = events[-1]
+        else:
+            self.native_event = events[-1].native_event
+
+    @property
+    def profile(self):
+        total = sum(evt.profile.end - evt.profile.start for evt in self.events)
+        end = self.native_event.profile.end
+        return ProfileGetter(start=end - total, end=end)
+
+    def wait(self):
+        return self.native_event.wait()
+
+
+class MarkerBasedProfilingEvent:
+    """An object to hold two marker events and provides compatibility
+    with some of the functionality of :class:`pyopencl.Event`.
+    """
+    def __init__(self, *, end_event, start_event):
+        self.native_event = end_event
+        self.start_event = start_event
+
+    @property
+    def profile(self):
+        return ProfileGetter(start=self.start_event.profile.start,
+                             end=self.native_event.profile.end)
+
+    def wait(self):
+        return self.native_event.wait()
+
+
+def get_opencl_fft_app(queue, shape, dtype):
+    """Setup an object for out-of-place FFT on with given shape and dtype
+    on given queue. Only supports in-order queues.
+    """
+    if queue.properties & cl.command_queue_properties.OUT_OF_ORDER_EXEC_MODE_ENABLE:
+        raise RuntimeError("VkFFT does not support out of order queues yet.")
+
+    assert dtype.type in (np.float32, np.float64, np.complex64,
+                           np.complex128)
+
+    from pyvkfft.opencl import VkFFTApp
+    app = VkFFTApp(shape=shape, dtype=dtype, queue=queue, ndim=1, inplace=False)
+    return app
+
+
+def run_opencl_fft(vkfft_app, queue, input_vec, inverse=False, wait_for=None):
+    """Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent`
+    that indicate the end and start of the operations carried out and the output
+    vector.
+    Only supports in-order queues.
+    """
+    if wait_for is None:
+        wait_for = []
+
+    start_evt = cl.enqueue_marker(queue, wait_for=wait_for[:])
+
+    if vkfft_app.inplace:
+        raise RuntimeError("inplace fft is not supported")
+    else:
+        output_vec = cla.empty_like(input_vec, queue)
+
+    # FIXME: use the public API once https://github.com/vincefn/pyvkfft/pull/17 is in
+    from pyvkfft.opencl import _vkfft_opencl
+    if inverse:
+        meth = _vkfft_opencl.ifft
+    else:
+        meth = _vkfft_opencl.fft
+
+    meth(vkfft_app.app, int(input_vec.data.int_ptr), int(output_vec.data.int_ptr),
+        int(queue.int_ptr))
+
+    end_evt = cl.enqueue_marker(queue, wait_for=[start_evt])
+    output_vec.add_event(end_evt)
+
+    return (MarkerBasedProfilingEvent(end_event=end_evt, start_event=start_evt),
+        output_vec)
+
 # }}}
 
 # vim: fdm=marker
diff --git a/test/test_fmm.py b/test/test_fmm.py
index 3e8ebf0b..a63b64fd 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -409,7 +409,8 @@ def test_unified_single_and_double(ctx_factory):
     assert rel_err < 1e-12
 
 
-def test_sumpy_fmm_timing_data_collection(ctx_factory):
+@pytest.mark.parametrize("use_fft", [True, False])
+def test_sumpy_fmm_timing_data_collection(ctx_factory, use_fft):
     logging.basicConfig(level=logging.INFO)
 
     ctx = ctx_factory()
@@ -448,10 +449,20 @@ def test_sumpy_fmm_timing_data_collection(ctx_factory):
 
     from functools import partial
 
+    if use_fft:
+        from sumpy.expansion.m2l import FFTM2LTranslationClassFactory
+        m2l_translation_factory = FFTM2LTranslationClassFactory()
+    else:
+        from sumpy.expansion.m2l import NonFFTM2LTranslationClassFactory
+        m2l_translation_factory = NonFFTM2LTranslationClassFactory()
+
+    m2l_translation = m2l_translation_factory.get_m2l_translation_class(
+                knl, local_expn_class)()
+
     tree_indep = SumpyTreeIndependentDataForWrangler(
             ctx,
             partial(mpole_expn_class, knl),
-            partial(local_expn_class, knl),
+            partial(local_expn_class, knl, m2l_translation=m2l_translation),
             target_kernels)
 
     wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
-- 
GitLab