From 15aa19296f733f456188ecda007ff177cbe32b78 Mon Sep 17 00:00:00 2001
From: Isuru Fernando <isuruf@gmail.com>
Date: Wed, 17 May 2023 14:13:32 -0500
Subject: [PATCH] Refactor E2P and P2E (#153)

* Use pytential branch

* Refactor E2P

* try new loopy branch

* fix formatting

* disable domains check

* register only if not found

* Move kernel_scaling to the outer kernel

* Refactor P2E

* Use loopy main

* re-enable implemented domains check

* Rename some loopy kernel handling functions

---------

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
---
 .github/workflows/ci.yml    |   6 +-
 sumpy/codegen.py            |   6 +-
 sumpy/e2p.py                | 175 +++++++++++++--------------
 sumpy/expansion/__init__.py |  25 +++-
 sumpy/expansion/loopy.py    | 227 ++++++++++++++++++++++++++++++++++++
 sumpy/kernel.py             |  30 +++--
 sumpy/p2e.py                | 145 ++++++++++++-----------
 7 files changed, 439 insertions(+), 175 deletions(-)
 create mode 100644 sumpy/expansion/loopy.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e88c47b7..893a8b7d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -101,9 +101,9 @@ jobs:
             run: |
                 curl -L -O https://tiker.net/ci-support-v0
                 . ./ci-support-v0
-                # if [[ "$DOWNSTREAM_PROJECT" == "pytential" && "$GITHUB_HEAD_REF" == "fft" ]]; then
-                #    DOWNSTREAM_PROJECT=https://github.com/isuruf/pytential.git@pyvkfft
-                # fi
+                if [[ "$DOWNSTREAM_PROJECT" == "pytential" && "$GITHUB_HEAD_REF" == "e2p" ]]; then
+                    DOWNSTREAM_PROJECT=https://github.com/isuruf/pytential.git@e2p
+                fi
                 test_downstream "$DOWNSTREAM_PROJECT"
 
 # vim: sw=4
diff --git a/sumpy/codegen.py b/sumpy/codegen.py
index 147dc499..fa13d526 100644
--- a/sumpy/codegen.py
+++ b/sumpy/codegen.py
@@ -200,9 +200,11 @@ class Hankel1_01(lp.ScalarCallable):  # noqa: N801
 
 def register_bessel_callables(loopy_knl):
     from sumpy.codegen import BesselJvvp1, Hankel1_01
-    loopy_knl = lp.register_callable(loopy_knl, "bessel_jvvp1",
+    if "bessel_jvvp1" not in loopy_knl.callables_table:
+        loopy_knl = lp.register_callable(loopy_knl, "bessel_jvvp1",
             BesselJvvp1("bessel_jvvp1"))
-    loopy_knl = lp.register_callable(loopy_knl, "hank1_01",
+    if "hank1_01" not in loopy_knl.callables_table:
+        loopy_knl = lp.register_callable(loopy_knl, "hank1_01",
             Hankel1_01("hank1_01"))
     return loopy_knl
 
diff --git a/sumpy/e2p.py b/sumpy/e2p.py
index a62236b0..f304a339 100644
--- a/sumpy/e2p.py
+++ b/sumpy/e2p.py
@@ -24,9 +24,8 @@ from abc import ABC, abstractmethod
 
 import numpy as np
 import loopy as lp
-import sumpy.symbolic as sym
 
-from sumpy.tools import KernelCacheMixin
+from sumpy.tools import KernelCacheMixin, gather_loopy_arguments
 from loopy.version import MOST_RECENT_LANGUAGE_VERSION
 
 
@@ -82,55 +81,32 @@ class E2PBase(KernelCacheMixin, ABC):
     def default_name(self):
         pass
 
-    def get_loopy_insns_and_result_names(self):
-        from sumpy.symbolic import make_sym_vector
-        bvec = make_sym_vector("b", self.dim)
-
-        import sumpy.symbolic as sp
-        rscale = sp.Symbol("rscale")
-
-        from sumpy.assignment_collection import SymbolicAssignmentCollection
-        sac = SymbolicAssignmentCollection()
-
-        coeff_exprs = [
-                sym.Symbol(f"coeff{i}")
-                for i in range(len(self.expansion.get_coefficient_identifiers()))]
-
-        result_names = [
-            sac.assign_unique(f"result_{i}_p",
-                self.expansion.evaluate(knl, coeff_exprs, bvec, rscale, sac=sac))
-            for i, knl in enumerate(self.kernels)
-            ]
-
-        sac.run_global_cse()
+    def get_cache_key(self):
+        return (type(self).__name__, self.expansion, tuple(self.kernels))
 
-        from sumpy.codegen import to_loopy_insns
-        loopy_insns = to_loopy_insns(
-                sac.assignments.items(),
-                vector_names={"b"},
-                pymbolic_expr_maps=[
-                    knl.get_code_transformer() for knl in self.kernels],
-                retain_names=result_names,
-                complex_dtype=np.complex128  # FIXME
-                )
+    def add_loopy_eval_callable(
+            self, loopy_knl: lp.TranslationUnit) -> lp.TranslationUnit:
+        inner_knl = self.expansion.get_loopy_evaluator(self.kernels)
+        loopy_knl = lp.merge([loopy_knl, inner_knl])
+        loopy_knl = lp.inline_callable_kernel(loopy_knl, "e2p")
+        loopy_knl = lp.remove_unused_inames(loopy_knl)
+        for kernel in self.kernels:
+            loopy_knl = kernel.prepare_loopy_kernel(loopy_knl)
+        loopy_knl = lp.tag_array_axes(loopy_knl, "targets", "sep,C")
+        return loopy_knl
 
-        return loopy_insns, result_names
+    def get_loopy_args(self):
+        return gather_loopy_arguments((self.expansion,) + tuple(self.kernels))
 
     def get_kernel_scaling_assignment(self):
         from sumpy.symbolic import SympyToPymbolicMapper
-        from sumpy.tools import ScalingAssignmentTag
         sympy_conv = SympyToPymbolicMapper()
-        return [lp.Assignment(id=None,
+        return [lp.Assignment(id="kernel_scaling",
                     assignee="kernel_scaling",
                     expression=sympy_conv(
                         self.expansion.kernel.get_global_scaling_const()),
                     temp_var_type=lp.Optional(None),
-                    tags=frozenset([ScalingAssignmentTag()]),
                     )]
-
-    def get_cache_key(self):
-        return (type(self).__name__, self.expansion, tuple(self.kernels))
-
 # }}}
 
 
@@ -143,14 +119,15 @@ class E2PFromSingleBox(E2PBase):
 
     def get_kernel(self):
         ncoeffs = len(self.expansion)
-
-        loopy_insns, result_names = self.get_loopy_insns_and_result_names()
+        loopy_args = self.get_loopy_args()
 
         loopy_knl = lp.make_kernel(
                 [
                     "{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
                     "{[itgt,idim]: itgt_start<=itgt<itgt_end and 0<=idim<dim}",
-                    ],
+                    "{[icoeff]: 0<=icoeff<ncoeffs}",
+                    "{[iknl]: 0<=iknl<nresults}",
+                ],
                 self.get_kernel_scaling_assignment()
                 + ["""
                 for itgt_box
@@ -160,27 +137,32 @@ class E2PFromSingleBox(E2PBase):
 
                     <> center[idim] = centers[idim, tgt_ibox] {id=fetch_center}
 
-                    """] + ["""
-                    <> coeff{coeffidx} = \
-                            src_expansions[tgt_ibox - src_base_ibox, {coeffidx}]
-                    """.format(coeffidx=i) for i in range(ncoeffs)] + ["""
+                    <> coeffs[icoeff] = \
+                            src_expansions[tgt_ibox - src_base_ibox, icoeff] \
+                            {id=fetch_coeffs}
 
                     for itgt
-                        <> b[idim] = targets[idim, itgt] - center[idim] {dup=idim}
-
-                        """] + loopy_insns + ["""
-
-                        result[{resultidx},itgt] = \
-                                kernel_scaling * result_{resultidx}_p \
-                                {{id_prefix=write_result}}
-                        """.format(resultidx=i) for i in range(len(result_names))
-                        ] + ["""
+                        <> tgt[idim] = targets[idim, itgt] {id=fetch_tgt,dup=idim}
+                        <> result_temp[iknl] = 0  {id=init_result,dup=iknl}
+                        [iknl]: result_temp[iknl] = e2p(
+                            [iknl]: result_temp[iknl],
+                            [icoeff]: coeffs[icoeff],
+                            [idim]: center[idim],
+                            [idim]: tgt[idim],
+                            rscale,
+                            itgt,
+                            ntargets,
+                            targets,
+                """ + ",".join(arg.name for arg in loopy_args) + """
+                        )  {dep=fetch_coeffs:fetch_center:init_result:fetch_tgt,\
+                                id=update_result}
+                        result[iknl, itgt] = result_temp[iknl] * kernel_scaling \
+                            {id=write_result,dep=update_result}
                     end
                 end
                 """],
                 [
-                    lp.GlobalArg("targets", None, shape=(self.dim, "ntargets"),
-                        dim_tags="sep,C"),
+                    lp.GlobalArg("targets", None, shape=(self.dim, "ntargets")),
                     lp.GlobalArg("box_target_starts,box_target_counts_nonchild",
                         None, shape=None),
                     lp.GlobalArg("centers", None, shape="dim, naligned_boxes"),
@@ -192,18 +174,20 @@ class E2PFromSingleBox(E2PBase):
                     lp.ValueArg("nsrc_level_boxes,naligned_boxes", np.int32),
                     lp.ValueArg("src_base_ibox", np.int32),
                     lp.ValueArg("ntargets", np.int32),
+                    *loopy_args,
                     "..."
-                ] + [arg.loopy_arg for arg in self.expansion.get_args()],
+                ],
                 name=self.name,
                 assumptions="ntgt_boxes>=1",
-                silenced_warnings="write_race(write_result*)",
+                silenced_warnings="write_race(*_result)",
                 default_offset=lp.auto,
-                fixed_parameters={"dim": self.dim, "nresults": len(result_names)},
+                fixed_parameters={"dim": self.dim, "nresults": len(self.kernels),
+                        "ncoeffs": ncoeffs},
                 lang_version=MOST_RECENT_LANGUAGE_VERSION)
 
         loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
-        for knl in self.kernels:
-            loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
+        loopy_knl = lp.tag_inames(loopy_knl, "iknl*:unr")
+        loopy_knl = self.add_loopy_eval_callable(loopy_knl)
 
         return loopy_knl
 
@@ -211,7 +195,7 @@ class E2PFromSingleBox(E2PBase):
         # FIXME
         knl = self.get_kernel()
         knl = lp.tag_inames(knl, {"itgt_box": "g.0"})
-        knl = self._allow_redundant_execution_of_knl_scaling(knl)
+        knl = lp.add_inames_to_insn(knl, "itgt_box", "id:kernel_scaling")
         knl = lp.set_options(knl,
                 enforce_variable_access_ordered="no_check")
 
@@ -247,8 +231,7 @@ class E2PFromCSR(E2PBase):
 
     def get_kernel(self):
         ncoeffs = len(self.expansion)
-
-        loopy_insns, result_names = self.get_loopy_insns_and_result_names()
+        loopy_args = self.get_loopy_args()
 
         loopy_knl = lp.make_kernel(
                 [
@@ -256,7 +239,9 @@ class E2PFromCSR(E2PBase):
                     "{[itgt]: itgt_start<=itgt<itgt_end}",
                     "{[isrc_box]: isrc_box_start<=isrc_box<isrc_box_end }",
                     "{[idim]: 0<=idim<dim}",
-                    ],
+                    "{[icoeff]: 0<=icoeff<ncoeffs}",
+                    "{[iknl]: 0<=iknl<nresults}",
+                ],
                 self.get_kernel_scaling_assignment()
                 + ["""
                 for itgt_box
@@ -265,35 +250,40 @@ class E2PFromCSR(E2PBase):
                     <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox]
 
                     for itgt
-                        <> tgt[idim] = targets[idim,itgt]
+                        <> tgt[idim] = targets[idim,itgt]  {id=fetch_tgt,dup=idim}
 
                         <> isrc_box_start = source_box_starts[itgt_box]
                         <> isrc_box_end = source_box_starts[itgt_box+1]
 
+                        <> result_temp[iknl] = 0 {id=init_result,dup=iknl}
                         for isrc_box
                             <> src_ibox = source_box_lists[isrc_box]
-                            """] + ["""
-                            <> coeff{coeffidx} = \
-                                src_expansions[src_ibox - src_base_ibox, {coeffidx}]
-                            """.format(coeffidx=i) for i in range(ncoeffs)] + ["""
-
-                            <> center[idim] = centers[idim, src_ibox] {dup=idim}
-                            <> b[idim] = tgt[idim] - center[idim] {dup=idim}
-
-                            """] + loopy_insns + ["""
+                            <> coeffs[icoeff] = \
+                                src_expansions[src_ibox - src_base_ibox, icoeff] \
+                                {id=fetch_coeffs,dup=icoeff}
+                            <> center[idim] = centers[idim, src_ibox] \
+                                {dup=idim,id=fetch_center}
+                            [iknl]: result_temp[iknl] = e2p(
+                                [iknl]: result_temp[iknl],
+                                [icoeff]: coeffs[icoeff],
+                                [idim]: center[idim],
+                                [idim]: tgt[idim],
+                                rscale,
+                                itgt,
+                                ntargets,
+                                targets,
+                """ + ",".join(arg.name for arg in loopy_args) + """
+                            )  {id=update_result, \
+                              dep=fetch_coeffs:fetch_center:fetch_tgt:init_result}
                         end
-                        """] + ["""
-                        result[{resultidx}, itgt] = result[{resultidx}, itgt] + \
-                                kernel_scaling * simul_reduce(sum, isrc_box,
-                                result_{resultidx}_p) {{id_prefix=write_result}}
-                        """.format(resultidx=i) for i in range(len(result_names))]
-                + ["""
+                        result[iknl, itgt] = result[iknl, itgt] + result_temp[iknl] \
+                                * kernel_scaling \
+                                {dep=update_result:init_result,id=write_result,dup=iknl}
                     end
                 end
                 """],
                 [
-                    lp.GlobalArg("targets", None, shape=(self.dim, "ntargets"),
-                        dim_tags="sep,C"),
+                    lp.GlobalArg("targets", None, shape=(self.dim, "ntargets")),
                     lp.GlobalArg("box_target_starts,box_target_counts_nonchild",
                         None, shape=None),
                     lp.GlobalArg("centers", None, shape="dim, aligned_nboxes"),
@@ -306,21 +296,24 @@ class E2PFromCSR(E2PBase):
                         dim_tags="sep,C"),
                     lp.GlobalArg("source_box_starts, source_box_lists,",
                         None, shape=None, offset=lp.auto),
+                    *loopy_args,
                     "..."
-                ] + [arg.loopy_arg for arg in self.expansion.get_args()],
+                ],
                 name=self.name,
                 assumptions="ntgt_boxes>=1",
-                silenced_warnings="write_race(write_result*)",
+                silenced_warnings="write_race(*_result)",
                 default_offset=lp.auto,
                 fixed_parameters={
+                        "ncoeffs": ncoeffs,
                         "dim": self.dim,
-                        "nresults": len(result_names)},
+                        "nresults": len(self.kernels)},
                 lang_version=MOST_RECENT_LANGUAGE_VERSION)
 
         loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
+        loopy_knl = lp.tag_inames(loopy_knl, "iknl*:unr")
         loopy_knl = lp.prioritize_loops(loopy_knl, "itgt_box,itgt,isrc_box")
-        for knl in self.kernels:
-            loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
+        loopy_knl = self.add_loopy_eval_callable(loopy_knl)
+        loopy_knl = lp.tag_array_axes(loopy_knl, "targets", "sep,C")
 
         return loopy_knl
 
@@ -328,7 +321,7 @@ class E2PFromCSR(E2PBase):
         # FIXME
         knl = self.get_kernel()
         knl = lp.tag_inames(knl, {"itgt_box": "g.0"})
-        knl = self._allow_redundant_execution_of_knl_scaling(knl)
+        knl = lp.add_inames_to_insn(knl, "itgt_box", "id:kernel_scaling")
         knl = lp.set_options(knl,
                 enforce_variable_access_ordered="no_check")
         return knl
diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py
index d4ef6e6a..3ee58917 100644
--- a/sumpy/expansion/__init__.py
+++ b/sumpy/expansion/__init__.py
@@ -21,9 +21,11 @@ THE SOFTWARE.
 """
 
 from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Dict, Hashable, List, Optional, Tuple, Type
+from typing import (
+        Any, ClassVar, Dict, Hashable, List, Optional, Sequence, Tuple, Type)
 
 from pytools import memoize_method
+import loopy as lp
 
 import sumpy.symbolic as sym
 from sumpy.kernel import Kernel
@@ -63,7 +65,9 @@ class ExpansionBase(ABC):
     .. automethod:: get_coefficient_identifiers
     .. automethod:: coefficients_from_source
     .. automethod:: coefficients_from_source_vec
+    .. automethod:: get_loopy_expansion_formation
     .. automethod:: evaluate
+    .. automethod:: get_loopy_evaluator
 
     .. automethod:: with_kernel
     .. automethod:: copy
@@ -159,6 +163,17 @@ class ExpansionBase(ABC):
                 result[i] += weight * coeffs[i]
         return result
 
+    def get_loopy_expansion_formation(
+            self, kernels: Sequence[Kernel],
+            strength_usage: Sequence[int], nstrengths: int) -> lp.TranslationUnit:
+        """
+        :returns: a :mod:`loopy` kernel that returns the coefficients
+            for the expansion given by *kernels* with each kernel using
+            the strength given by *strength_usage*.
+        """
+        from sumpy.expansion.loopy import make_p2e_loopy_kernel
+        return make_p2e_loopy_kernel(self, kernels, strength_usage, nstrengths)
+
     @abstractmethod
     def evaluate(self, kernel, coeffs, bvec, rscale, sac=None):
         """
@@ -167,6 +182,14 @@ class ExpansionBase(ABC):
             in *coeffs*.
         """
 
+    def get_loopy_evaluator(self, kernels: Sequence[Kernel]) -> lp.TranslationUnit:
+        """
+        :returns: a :mod:`loopy` kernel that returns the evaluated
+            target transforms of the potential given by *kernels*.
+        """
+        from sumpy.expansion.loopy import make_e2p_loopy_kernel
+        return make_e2p_loopy_kernel(self, kernels)
+
     # }}}
 
     # {{{ copy
diff --git a/sumpy/expansion/loopy.py b/sumpy/expansion/loopy.py
new file mode 100644
index 00000000..00bbd3db
--- /dev/null
+++ b/sumpy/expansion/loopy.py
@@ -0,0 +1,227 @@
+__copyright__ = "Copyright (C) 2022 Isuru Fernando"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from typing import Sequence
+import pymbolic
+import loopy as lp
+import numpy as np
+from sumpy.expansion import ExpansionBase
+from sumpy.kernel import Kernel
+import sumpy.symbolic as sym
+from sumpy.assignment_collection import SymbolicAssignmentCollection
+from sumpy.tools import gather_loopy_arguments, gather_loopy_source_arguments
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def make_e2p_loopy_kernel(
+        expansion: ExpansionBase, kernels: Sequence[Kernel]) -> lp.TranslationUnit:
+    """
+    This is a helper function to create a loopy kernel for multipole/local
+    evaluation. This function uses symbolic expressions given by the expansion class,
+    converts them to pymbolic expressions and generates a loopy
+    kernel. Note that the loopy kernel returned has lots of expressions in it and
+    takes a long time. Therefore, this function should be used only as a fallback
+    when there is no "loop-y" kernel to evaluate the expansion.
+    """
+    dim = expansion.dim
+
+    bvec = sym.make_sym_vector("b", dim)
+    ncoeffs = len(expansion.get_coefficient_identifiers())
+
+    rscale = sym.Symbol("rscale")
+
+    sac = SymbolicAssignmentCollection()
+
+    domains = [
+        "{[idim]: 0<=idim<dim}",
+        "{[iknl]: 0<=iknl<nresults}",
+    ]
+    insns = []
+    insns.append(
+        lp.Assignment(
+            assignee="b[idim]",
+            expression="target[idim]-center[idim]",
+            temp_var_type=lp.Optional(None),
+        ))
+    target_args = gather_loopy_arguments((expansion,) + tuple(kernels))
+
+    coeff_exprs = sym.make_sym_vector("coeffs", ncoeffs)
+    coeff_names = [
+        sac.add_assignment(f"result{i}",
+            expansion.evaluate(knl, coeff_exprs, bvec, rscale, sac=sac))
+        for i, knl in enumerate(kernels)]
+
+    sac.run_global_cse()
+
+    code_transformers = [expansion.get_code_transformer()] \
+        + [kernel.get_code_transformer() for kernel in kernels]
+
+    from sumpy.codegen import to_loopy_insns
+    insns += to_loopy_insns(
+            sac.assignments.items(),
+            vector_names={"b", "coeffs"},
+            pymbolic_expr_maps=code_transformers,
+            retain_names=coeff_names,
+            complex_dtype=np.complex128  # FIXME
+            )
+
+    result = pymbolic.var("result")
+
+    # change result{i} = expr into result[i] += expr
+    for i in range(len(insns)):
+        insn = insns[i]
+        if isinstance(insn, lp.Assignment) and \
+                isinstance(insn.assignee, pymbolic.var) and \
+                insn.assignee.name.startswith(result.name):
+            idx = int(insn.assignee.name[len(result.name):])
+            insns[i] = lp.Assignment(
+                assignee=result[idx],
+                expression=result[idx] + insn.expression,
+                id=f"result_{idx}",
+                depends_on=insn.depends_on,
+            )
+
+    loopy_knl = lp.make_function(domains, insns,
+            kernel_data=[
+                lp.GlobalArg("result", shape=(len(kernels),), is_input=True,
+                    is_output=True),
+                lp.GlobalArg("coeffs",
+                    shape=(ncoeffs,), is_input=True, is_output=False),
+                lp.GlobalArg("center, target",
+                    shape=(dim,), is_input=True, is_output=False),
+                lp.ValueArg("rscale", is_input=True),
+                lp.ValueArg("itgt", is_input=True),
+                lp.ValueArg("ntargets", is_input=True),
+                lp.GlobalArg("targets",
+                    shape=(dim, "ntargets"), is_input=True, is_output=False),
+                *target_args,
+                ...],
+            name="e2p",
+            lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+            fixed_parameters={"dim": dim, "nresults": len(kernels)},
+            )
+
+    loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
+    for kernel in kernels:
+        loopy_knl = kernel.prepare_loopy_kernel(loopy_knl)
+
+    return loopy_knl
+
+
+def make_p2e_loopy_kernel(
+        expansion: ExpansionBase, kernels: Sequence[Kernel],
+        strength_usage: Sequence[int], nstrengths: int) -> lp.TranslationUnit:
+    """
+    This is a helper function to create a loopy kernel for multipole/local
+    expression. This function uses symbolic expressions given by the expansion
+    class, converts them to pymbolic expressions and generates a loopy
+    kernel. Note that the loopy kernel returned has lots of expressions in it and
+    takes a long time. Therefore, this function should be used only as a fallback
+    when there is no "loop-y" kernel to evaluate the expansion.
+    """
+    dim = expansion.dim
+
+    avec = sym.make_sym_vector("a", dim)
+    ncoeffs = len(expansion.get_coefficient_identifiers())
+
+    rscale = sym.Symbol("rscale")
+
+    sac = SymbolicAssignmentCollection()
+
+    domains = [
+        "{[idim]: 0<=idim<dim}",
+    ]
+    insns = []
+    insns.append(
+        lp.Assignment(
+            assignee="a[idim]",
+            expression="center[idim]-source[idim]",
+            temp_var_type=lp.Optional(None),
+        ))
+    source_args = gather_loopy_source_arguments((expansion,) + tuple(kernels))
+
+    all_strengths = sym.make_sym_vector("strength", nstrengths)
+    strengths = [all_strengths[i] for i in strength_usage]
+    coeffs = expansion.coefficients_from_source_vec(kernels,
+        avec, None, rscale, strengths, sac=sac)
+
+    coeff_names = [
+        sac.add_assignment(f"coeffs{i}", coeff) for i, coeff in enumerate(coeffs)
+    ]
+
+    sac.run_global_cse()
+
+    code_transformers = [expansion.get_code_transformer()] \
+        + [kernel.get_code_transformer() for kernel in kernels]
+
+    from sumpy.codegen import to_loopy_insns
+    insns += to_loopy_insns(
+            sac.assignments.items(),
+            vector_names={"a", "strength"},
+            pymbolic_expr_maps=code_transformers,
+            retain_names=coeff_names,
+            complex_dtype=np.complex128  # FIXME
+            )
+
+    coeffs = pymbolic.var("coeffs")
+
+    # change coeff{i} = expr into coeff[i] += expr
+    for i in range(len(insns)):
+        insn = insns[i]
+        if isinstance(insn, lp.Assignment) and \
+                isinstance(insn.assignee, pymbolic.var) and \
+                insn.assignee.name.startswith(coeffs.name):
+            idx = int(insn.assignee.name[len(coeffs.name):])
+            insns[i] = lp.Assignment(
+                assignee=coeffs[idx],
+                expression=coeffs[idx] + insn.expression,
+                id=f"coeff_{idx}",
+                depends_on=insn.depends_on,
+            )
+
+    loopy_knl = lp.make_function(domains, insns,
+            kernel_data=[
+                lp.GlobalArg("coeffs",
+                    shape=(ncoeffs,), is_input=True, is_output=True),
+                lp.GlobalArg("center, source",
+                    shape=(dim,), is_input=True, is_output=False),
+                lp.GlobalArg("strength",
+                    shape=(nstrengths,), is_input=True, is_output=False),
+                lp.ValueArg("rscale", is_input=True),
+                lp.ValueArg("isrc", is_input=True),
+                lp.ValueArg("nsources", is_input=True),
+                lp.GlobalArg("sources",
+                    shape=(dim, "nsources"), is_input=True, is_output=False),
+                *source_args,
+                ...],
+            name="p2e",
+            lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+            fixed_parameters={"dim": dim},
+            )
+
+    loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
+    for kernel in kernels:
+        loopy_knl = kernel.prepare_loopy_kernel(loopy_knl)
+
+    return loopy_knl
diff --git a/sumpy/kernel.py b/sumpy/kernel.py
index e7d7175b..1be1dec9 100644
--- a/sumpy/kernel.py
+++ b/sumpy/kernel.py
@@ -1171,17 +1171,22 @@ class DirectionalTargetDerivative(DirectionalDerivative):
         return DifferentiatedExprDerivativeTaker(expr.taker,
                 dict(new_transformation))
 
-    def get_source_args(self):
+    def get_args(self):
         return [
-                KernelArgument(
-                    loopy_arg=lp.GlobalArg(
-                        self.dir_vec_name,
-                        None,
-                        shape=(self.dim, "ntargets"),
-                        dim_tags="sep,C",
-                        offset=lp.auto),
-                    )
-                    ] + self.inner_kernel.get_source_args()
+            KernelArgument(
+                loopy_arg=lp.GlobalArg(
+                    self.dir_vec_name,
+                    None,
+                    shape=(self.dim, "ntargets"),
+                    offset=lp.auto
+                ),
+            ),
+            *self.inner_kernel.get_args()
+        ]
+
+    def prepare_loopy_kernel(self, loopy_knl):
+        loopy_knl = self.inner_kernel.prepare_loopy_kernel(loopy_knl)
+        return lp.tag_array_axes(loopy_knl, self.dir_vec_name, "sep,C")
 
     mapper_method = "map_directional_target_derivative"
 
@@ -1224,11 +1229,14 @@ class DirectionalSourceDerivative(DirectionalDerivative):
                         self.dir_vec_name,
                         None,
                         shape=(self.dim, "nsources"),
-                        dim_tags="sep,C",
                         offset=lp.auto),
                     )
                     ] + self.inner_kernel.get_source_args()
 
+    def prepare_loopy_kernel(self, loopy_knl):
+        loopy_knl = self.inner_kernel.prepare_loopy_kernel(loopy_knl)
+        return lp.tag_array_axes(loopy_knl, self.dir_vec_name, "sep,C")
+
     mapper_method = "map_directional_source_derivative"
 
 
diff --git a/sumpy/p2e.py b/sumpy/p2e.py
index c41dbb48..8c25415a 100644
--- a/sumpy/p2e.py
+++ b/sumpy/p2e.py
@@ -86,37 +86,22 @@ class P2EBase(KernelCacheMixin, KernelComputation):
         self.expansion = expansion
         self.dim = expansion.dim
 
-    def get_loopy_instructions(self):
-        from sumpy.symbolic import make_sym_vector
-        avec = make_sym_vector("a", self.dim)
-
-        import sumpy.symbolic as sp
-        rscale = sp.Symbol("rscale")
-
-        from sumpy.assignment_collection import SymbolicAssignmentCollection
-        sac = SymbolicAssignmentCollection()
-
-        strengths = [sp.Symbol(f"strength_{i}") for i in self.strength_usage]
-        coeffs = self.expansion.coefficients_from_source_vec(self.source_kernels,
-                    avec, None, rscale, strengths, sac=sac)
-
-        coeff_names = []
-        for i, coeff in enumerate(coeffs):
-            sac.add_assignment(f"coeff{i}", coeff)
-            coeff_names.append(f"coeff{i}")
-
-        sac.run_global_cse()
-
-        code_transformers = [self.expansion.get_code_transformer()] \
-            + [kernel.get_code_transformer() for kernel in self.source_kernels]
+    def add_loopy_form_callable(
+            self, loopy_knl: lp.TranslationUnit) -> lp.TranslationUnit:
+        inner_knl = self.expansion.get_loopy_expansion_formation(
+            self.source_kernels, self.strength_usage, self.strength_count)
+        loopy_knl = lp.merge([loopy_knl, inner_knl])
+        loopy_knl = lp.inline_callable_kernel(loopy_knl, "p2e")
+        loopy_knl = lp.remove_unused_inames(loopy_knl)
+        for kernel in self.source_kernels:
+            loopy_knl = kernel.prepare_loopy_kernel(loopy_knl)
+        loopy_knl = lp.tag_array_axes(loopy_knl, "strengths", "sep,C")
+        return loopy_knl
 
-        from sumpy.codegen import to_loopy_insns
-        return to_loopy_insns(
-                sac.assignments.items(),
-                vector_names={"a"},
-                pymbolic_expr_maps=code_transformers,
-                retain_names=coeff_names,
-                )
+    def get_loopy_args(self):
+        from sumpy.tools import gather_loopy_source_arguments
+        return gather_loopy_source_arguments(
+                (self.expansion,) + tuple(self.source_kernels))
 
     def get_cache_key(self):
         return (type(self).__name__, self.name, self.expansion,
@@ -166,11 +151,14 @@ class P2EFromSingleBox(P2EBase):
 
     def get_kernel(self):
         ncoeffs = len(self.expansion)
+        loopy_args = self.get_loopy_args()
 
-        from sumpy.tools import gather_loopy_source_arguments
         loopy_knl = lp.make_kernel([
                 "{[isrc_box]: 0 <= isrc_box < nsrc_boxes}",
-                "{[isrc, idim]: isrc_start <= isrc < isrc_end and 0 <= idim < dim}",
+                "{[isrc]: isrc_start <= isrc < isrc_end}",
+                "{[idim]: 0 <= idim < dim}",
+                "{[icoeff]: 0 <= icoeff < ncoeffs}",
+                "{[istrength]: 0 <= istrength < nstrengths}",
                 ], ["""
                 for isrc_box
                     <> src_ibox = source_boxes[isrc_box]
@@ -179,25 +167,35 @@ class P2EFromSingleBox(P2EBase):
 
                     <> center[idim] = centers[idim, src_ibox] {id=fetch_center}
 
+                    <> coeffs[icoeff] = 0  {id=init_coeffs,dup=icoeff}
                     for isrc
-                        <> a[idim] = center[idim] - sources[idim, isrc] {dup=idim}
-                        """] + [
-                        f"<> strength_{i} = strengths[{i}, isrc]"
-                        for i in set(self.strength_usage)
-                        ] + self.get_loopy_instructions() + ["""
+                        <> source[idim] = sources[idim, isrc] \
+                                {dup=idim,id=fetch_src}
+                        <> strength[istrength] = strengths[istrength, isrc] \
+                                {dup=istrength,id=fetch_strength}
+                        [icoeff]: coeffs[icoeff] = p2e(
+                                [icoeff]: coeffs[icoeff],
+                                [idim]: center[idim],
+                                [idim]: source[idim],
+                                [istrength]: strength[istrength],
+                                rscale,
+                                isrc,
+                                nsources,
+                                sources,
+                """ + ",".join(arg.name for arg in loopy_args) + """
+                            )  {id=update_result, \
+                              dep=fetch_center:fetch_src:init_coeffs}
                     end
-                    """] + [f"""
-                    tgt_expansions[src_ibox - tgt_base_ibox, {coeffidx}] = \
-                        simul_reduce(sum, isrc, coeff{coeffidx}) \
-                            {{id_prefix=write_expn}}
-                    """ for coeffidx in range(ncoeffs)] + ["""
+                    tgt_expansions[src_ibox - tgt_base_ibox, icoeff] = \
+                        coeffs[icoeff] {id=write_expn,dup=icoeff,\
+                        dep=update_result:init_coeffs}
                 end
                 """],
                 [
                     lp.GlobalArg("sources", None,
                         shape=(self.dim, "nsources"), order="C"),
                     lp.GlobalArg("strengths", None,
-                        shape=("strength_count", "nsources"), dim_tags="sep,C"),
+                        shape=(self.strength_count, "nsources")),
                     lp.GlobalArg("box_source_starts, box_source_counts_nonchild",
                         None, shape=None),
                     lp.GlobalArg("centers", None, shape="dim, aligned_nboxes"),
@@ -206,20 +204,21 @@ class P2EFromSingleBox(P2EBase):
                         shape=("nboxes", ncoeffs), offset=lp.auto),
                     lp.ValueArg("nboxes, aligned_nboxes, tgt_base_ibox", np.int32),
                     lp.ValueArg("nsources", np.int32),
+                    *loopy_args,
                     ...
-                ] + gather_loopy_source_arguments(
-                    self.source_kernels + (self.expansion,)),
+                ],
                 name=self.name,
                 assumptions="nsrc_boxes>=1",
                 silenced_warnings="write_race(write_expn*)",
                 default_offset=lp.auto,
                 fixed_parameters={
-                    "dim": self.dim, "strength_count": self.strength_count},
+                    "dim": self.dim, "nstrengths": self.strength_count,
+                    "ncoeffs": ncoeffs},
                 lang_version=MOST_RECENT_LANGUAGE_VERSION)
 
-        for knl in self.source_kernels:
-            loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
         loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
+        loopy_knl = lp.tag_inames(loopy_knl, "istrength*:unr")
+        loopy_knl = self.add_loopy_form_callable(loopy_knl)
 
         return loopy_knl
 
@@ -270,14 +269,14 @@ class P2EFromCSR(P2EBase):
 
     def get_kernel(self):
         ncoeffs = len(self.expansion)
+        loopy_args = self.get_loopy_args()
 
-        from sumpy.tools import gather_loopy_source_arguments
         arguments = (
                 [
                     lp.GlobalArg("sources", None,
                         shape=(self.dim, "nsources"), order="C"),
                     lp.GlobalArg("strengths", None,
-                        shape=("strength_count", "nsources"), dim_tags="sep,C"),
+                        shape=(self.strength_count, "nsources")),
                     lp.GlobalArg("source_box_starts,source_box_lists",
                         None, shape=None, offset=lp.auto),
                     lp.GlobalArg("box_source_starts,box_source_counts_nonchild",
@@ -288,9 +287,9 @@ class P2EFromCSR(P2EBase):
                     lp.ValueArg("naligned_boxes,ntgt_level_boxes,tgt_base_ibox",
                         np.int32),
                     lp.ValueArg("nsources", np.int32),
+                    *loopy_args,
                     ...
-                ] + gather_loopy_source_arguments(
-                    self.source_kernels + (self.expansion,)))
+                ])
 
         loopy_knl = lp.make_kernel(
                 [
@@ -298,6 +297,8 @@ class P2EFromCSR(P2EBase):
                     "{[isrc_box]: isrc_box_start <= isrc_box < isrc_box_stop}",
                     "{[isrc]: isrc_start <= isrc < isrc_end}",
                     "{[idim]: 0 <= idim < dim}",
+                    "{[icoeff]: 0 <= icoeff < ncoeffs}",
+                    "{[istrength]: 0 <= istrength < nstrengths}",
                     ],
                 ["""
                 for itgt_box
@@ -307,6 +308,7 @@ class P2EFromCSR(P2EBase):
                     <> isrc_box_start = source_box_starts[itgt_box]
                     <> isrc_box_stop = source_box_starts[itgt_box + 1]
 
+                    <> coeffs[icoeff] = 0  {id=init_coeffs,dup=icoeff}
                     for isrc_box
                         <> src_ibox = source_box_lists[isrc_box]
                         <> isrc_start = box_source_starts[src_ibox]
@@ -314,19 +316,27 @@ class P2EFromCSR(P2EBase):
                                 + box_source_counts_nonchild[src_ibox]
 
                         for isrc
-                            <> a[idim] = center[idim] - sources[idim, isrc] \
-                                    {dup=idim}
-                    """] + [
-                        f"""
-                             <> strength_{i} = strengths[{i}, isrc]
-                             """ for i in set(self.strength_usage)
-                             ] + self.get_loopy_instructions() + ["""
+                            <> source[idim] = sources[idim, isrc] \
+                                    {dup=idim,id=fetch_src}
+                            <> strength[istrength] = strengths[istrength, isrc] \
+                                    {dup=istrength,id=fetch_strength}
+                            [icoeff]: coeffs[icoeff] = p2e(
+                                    [icoeff]: coeffs[icoeff],
+                                    [idim]: center[idim],
+                                    [idim]: source[idim],
+                                    [istrength]: strength[istrength],
+                                    rscale,
+                                    isrc,
+                                    nsources,
+                                    sources,
+                    """ + ",".join(arg.name for arg in loopy_args) + """
+                                )  {id=update_result, \
+                                  dep=fetch_center:fetch_src:init_coeffs}
                         end
-                    end"""] + [f"""
-                    tgt_expansions[tgt_ibox - tgt_base_ibox, {coeffidx}] = \
-                            simul_reduce(sum, (isrc_box, isrc), coeff{coeffidx}) \
-                            {{id_prefix=write_expn}}
-                    """ for coeffidx in range(ncoeffs)] + ["""
+                    end
+                    tgt_expansions[tgt_ibox - tgt_base_ibox, icoeff] = \
+                            coeffs[icoeff] {id=write_expn,dup=icoeff, \
+                            dep=update_result:init_coeffs}
                 end
                 """],
                 arguments,
@@ -335,12 +345,13 @@ class P2EFromCSR(P2EBase):
                 silenced_warnings="write_race(write_expn*)",
                 default_offset=lp.auto,
                 fixed_parameters={"dim": self.dim,
-                                  "strength_count": self.strength_count},
+                                  "nstrengths": self.strength_count,
+                                  "ncoeffs": ncoeffs},
                 lang_version=MOST_RECENT_LANGUAGE_VERSION)
 
-        for knl in self.source_kernels:
-            loopy_knl = knl.prepare_loopy_kernel(loopy_knl)
         loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
+        loopy_knl = lp.tag_inames(loopy_knl, "istrength*:unr")
+        loopy_knl = self.add_loopy_form_callable(loopy_knl)
 
         return loopy_knl
 
-- 
GitLab