From 82d6d431d32770ea02ade64112f272327c037ab5 Mon Sep 17 00:00:00 2001
From: Isuru Fernando <isuruf@gmail.com>
Date: Tue, 23 May 2023 17:37:45 -0500
Subject: [PATCH] Optimize preprocess multipole and postprocess local (#156)

---
 sumpy/e2e.py                | 35 ++++++++++----
 sumpy/expansion/__init__.py | 93 ++++++++++++++++++++++++++++++++++---
 sumpy/expansion/m2l.py      | 93 ++++++++++++++++++++++++++-----------
 test/test_misc.py           | 42 +++++++++++++++++
 4 files changed, 220 insertions(+), 43 deletions(-)

diff --git a/sumpy/e2e.py b/sumpy/e2e.py
index c357082b..7899c7b4 100644
--- a/sumpy/e2e.py
+++ b/sumpy/e2e.py
@@ -668,14 +668,20 @@ class M2LPreprocessMultipole(E2EBase):
     def default_name(self):
         return "m2l_preprocess_multipole"
 
+    @memoize_method
+    def get_inner_knl_and_optimizations(self, result_dtype):
+        m2l_translation = self.tgt_expansion.m2l_translation
+        return m2l_translation.preprocess_multipole_loopy_knl(
+            self.tgt_expansion, self.src_expansion, result_dtype)
+
     def get_kernel(self, result_dtype):
         m2l_translation = self.tgt_expansion.m2l_translation
         nsrc_coeffs = len(self.src_expansion)
         npreprocessed_src_coeffs = \
             m2l_translation.preprocess_multipole_nexprs(self.tgt_expansion,
                 self.src_expansion)
-        single_box_preprocess_knl = m2l_translation.preprocess_multipole_loopy_knl(
-            self.tgt_expansion, self.src_expansion, result_dtype)
+        single_box_preprocess_knl, _ = self.get_inner_knl_and_optimizations(
+                result_dtype)
 
         from sumpy.tools import gather_loopy_arguments
         loopy_knl = lp.make_kernel(
@@ -721,11 +727,11 @@ class M2LPreprocessMultipole(E2EBase):
         return loopy_knl
 
     def get_optimized_kernel(self, result_dtype):
-        # FIXME
         knl = self.get_kernel(result_dtype)
-        knl = lp.split_iname(knl, "isrc_box", 64, outer_tag="g.0",
-                             within=f"in_kernel:{self.name}")
-        knl = lp.add_inames_for_unused_hw_axes(knl)
+        knl = lp.tag_inames(knl, "isrc_box:g.0")
+        _, optimizations = self.get_inner_knl_and_optimizations(result_dtype)
+        for optimization in optimizations:
+            knl = optimization(knl)
         return knl
 
     def __call__(self, queue, **kwargs):
@@ -752,6 +758,12 @@ class M2LPostprocessLocal(E2EBase):
     def default_name(self):
         return "m2l_postprocess_local"
 
+    @memoize_method
+    def get_inner_knl_and_optimizations(self, result_dtype):
+        m2l_translation = self.tgt_expansion.m2l_translation
+        return m2l_translation.postprocess_local_loopy_knl(
+            self.tgt_expansion, self.src_expansion, result_dtype)
+
     def get_kernel(self, result_dtype):
         m2l_translation = self.tgt_expansion.m2l_translation
         ntgt_coeffs = len(self.tgt_expansion)
@@ -759,8 +771,8 @@ class M2LPostprocessLocal(E2EBase):
             m2l_translation.postprocess_local_nexprs(self.tgt_expansion,
                 self.src_expansion)
 
-        single_box_postprocess_knl = m2l_translation.postprocess_local_loopy_knl(
-            self.tgt_expansion, self.src_expansion, result_dtype)
+        single_box_postprocess_knl, _ = self.get_inner_knl_and_optimizations(
+                result_dtype)
 
         from sumpy.tools import gather_loopy_arguments
         loopy_knl = lp.make_kernel(
@@ -813,9 +825,12 @@ class M2LPostprocessLocal(E2EBase):
         return loopy_knl
 
     def get_optimized_kernel(self, result_dtype):
-        # FIXME
         knl = self.get_kernel(result_dtype)
-        knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0")
+        knl = lp.tag_inames(knl, "itgt_box:g.0")
+        _, optimizations = self.get_inner_knl_and_optimizations(result_dtype)
+        for optimization in optimizations:
+            knl = optimization(knl)
+        knl = lp.add_inames_for_unused_hw_axes(knl)
         return knl
 
     def __call__(self, queue, **kwargs):
diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py
index 3ee58917..c667ba99 100644
--- a/sumpy/expansion/__init__.py
+++ b/sumpy/expansion/__init__.py
@@ -30,6 +30,7 @@ import loopy as lp
 import sumpy.symbolic as sym
 from sumpy.kernel import Kernel
 from sumpy.tools import add_mi
+import pymbolic.primitives as prim
 
 import logging
 logger = logging.getLogger(__name__)
@@ -377,6 +378,18 @@ class ExpansionTermsWrangler(ABC):
 
 
 class FullExpansionTermsWrangler(ExpansionTermsWrangler):
+
+    def get_storage_index(self, mi, order=None):
+        if not order:
+            order = sum(mi)
+        if self.dim == 3:
+            return (order*(order + 1)*(order + 2))//6 + \
+                    (order + 2)*mi[2] - (mi[2]*(mi[2] + 1))//2 + mi[1]
+        elif self.dim == 2:
+            return (order*(order + 1))//2 + mi[1]
+        else:
+            raise NotImplementedError
+
     def get_coefficient_identifiers(self):
         return super().get_full_coefficient_identifiers()
 
@@ -389,6 +402,26 @@ class FullExpansionTermsWrangler(ExpansionTermsWrangler):
         return self.get_full_kernel_derivatives_from_stored(
             full_mpole_coefficients, rscale, sac=sac)
 
+    @memoize_method
+    def _get_mi_ordering_key_and_axis_permutation(self):
+        """
+        Returns a degree lexicographic order as a callable that can be used as a
+        ``sort`` key on multi-indices and a permutation of the axis ordered
+        from the slowest varying axis to the fastest varying axis of the
+        multi-indices when sorted.
+        """
+        from sumpy.expansion.diff_op import DerivativeIdentifier
+
+        axis_permutation = list(reversed(list(range(self.dim))))
+
+        def mi_key(ident):
+            if isinstance(ident, DerivativeIdentifier):
+                mi = ident.mi
+            else:
+                mi = ident
+            return tuple([sum(mi)] + list(reversed(mi)))
+
+        return mi_key, axis_permutation
 # }}}
 
 
@@ -520,7 +553,7 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
     # the axes so that the axis with the on-axis coefficient comes first in the
     # multi-index tuple.
     @memoize_method
-    def _get_mi_ordering_key(self):
+    def _get_mi_ordering_key_and_axis_permutation(self):
         """
         A degree lexicographic order with the slowest varying index depending on
         the PDE is used, returned as a callable that can be used as a
@@ -529,6 +562,9 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
         multipole-to-multipole translation to get lower error bounds.
         The slowest varying index is chosen such that the multipole-to-local
         translation cost is optimized.
+
+        Also returns a permutation of the axis ordered from the slowest varying
+        axis to the fastest varying axis of the multi-indices when sorted.
         """
         dim = self.dim
         deriv_id_to_coeff, = self.knl.get_pde_as_diff_op().eqs
@@ -554,7 +590,7 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
                 key.append(mi[axis_permutation[i]])
             return tuple(key)
 
-        return mi_key
+        return mi_key, axis_permutation
 
     def _get_mi_hyperpplanes(self) -> List[Tuple[int, int]]:
         mis = self.get_full_coefficient_identifiers()
@@ -570,8 +606,8 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
         else:
             # Calculate the multi-index that appears last in in the PDE in
             # the degree lexicographic order given by
-            # _get_mi_ordering_key.
-            ordering_key = self._get_mi_ordering_key()
+            # _get_mi_ordering_key_and_axis_permutation.
+            ordering_key, _ = self._get_mi_ordering_key_and_axis_permutation()
             max_mi = max(deriv_id_to_coeff, key=ordering_key).mi
             hyperplanes = [(d, const)
                 for d in range(self.dim)
@@ -581,9 +617,54 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
 
     def get_full_coefficient_identifiers(self):
         identifiers = super().get_full_coefficient_identifiers()
-        key = self._get_mi_ordering_key()
+        key, _ = self._get_mi_ordering_key_and_axis_permutation()
         return sorted(identifiers, key=key)
 
+    def get_storage_index(self, mi, order=None):
+        if not order:
+            order = sum(mi)
+
+        ordering_key, axis_permutation = \
+                self._get_mi_ordering_key_and_axis_permutation()
+        deriv_id_to_coeff, = self.knl.get_pde_as_diff_op().eqs
+        max_mi = max(deriv_id_to_coeff, key=ordering_key).mi
+
+        if all(m != 0 for m in max_mi):
+            raise NotImplementedError("non-elliptic PDEs")
+
+        c = max_mi[axis_permutation[0]]
+
+        mi = list(mi)
+        mi[axis_permutation[0]], mi[0] = mi[0], mi[axis_permutation[0]]
+
+        if self.dim == 3:
+            if all(isinstance(axis, int) for axis in mi):
+                if order < c - 1:
+                    return (order*(order + 1)*(order + 2))//6 + \
+                        (order + 2)*mi[0] - (mi[0]*(mi[0] + 1))//2 + mi[1]
+                else:
+                    return (c*(c-1)*(c-2))//6 + (c * order * (2 + order - c)
+                        + mi[0]*(3 - mi[0]+2*order))//2 + mi[1]
+            else:
+                return prim.If(prim.Comparison(order, "<", c - 1),
+                    (order*(order + 1)*(order + 2))//6
+                        + (order + 2)*mi[0] - (mi[0]*(mi[0] + 1))//2 + mi[1],
+                    (c*(c-1)*(c-2))//6 + (c * order * (2 + order - c)
+                        + mi[0]*(3 - mi[0]+2*order))//2 + mi[1]
+                )
+        elif self.dim == 2:
+            if all(isinstance(axis, int) for axis in mi):
+                if order < c - 1:
+                    return (order*(order + 1))//2 + mi[0]
+                else:
+                    return (c*(c-1))//2 + c*(order - c + 1) + mi[0]
+            else:
+                return prim.If(prim.Comparison(order, "<", c - 1),
+                    (order*(order + 1))//2 + mi[0],
+                    (c*(c-1))//2 + c*(order - c + 1) + mi[0])
+        else:
+            raise NotImplementedError
+
     @memoize_method
     def get_stored_ids_and_unscaled_projection_matrix(self):
         from pytools import ProcessLogger
@@ -608,7 +689,7 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
                                        from_output_coeffs_by_row, shape)
                 return mis, op
 
-        ordering_key = self._get_mi_ordering_key()
+        ordering_key, _ = self._get_mi_ordering_key_and_axis_permutation()
         max_mi = max((ident for ident in mi_to_coeff.keys()), key=ordering_key)
         max_mi_coeff = mi_to_coeff[max_mi]
         max_mi_mult = -1/sym.sympify(max_mi_coeff)
diff --git a/sumpy/expansion/m2l.py b/sumpy/expansion/m2l.py
index dbcc7d26..a56d91b5 100644
--- a/sumpy/expansion/m2l.py
+++ b/sumpy/expansion/m2l.py
@@ -440,64 +440,98 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase):
     def preprocess_multipole_loopy_knl(self, tgt_expansion, src_expansion,
             result_dtype):
 
-        circulant_matrix_mis, _, _ = \
+        circulant_matrix_mis, _, max_mi = \
             self._translation_classes_dependent_data_mis(tgt_expansion,
                 src_expansion)
-        circulant_matrix_ident_to_index = {
-                ident: i for i, ident in enumerate(circulant_matrix_mis)}
 
         ncoeff_src = len(src_expansion.get_coefficient_identifiers())
         ncoeff_preprocessed = self.preprocess_multipole_nexprs(tgt_expansion,
             src_expansion)
+        order = src_expansion.order
 
         output_coeffs = pymbolic.var("output_coeffs")
         input_coeffs = pymbolic.var("input_coeffs")
-        srcidx_sym = pymbolic.var("srcidx")
         output_icoeff = pymbolic.var("output_icoeff")
         input_icoeff = pymbolic.var("input_icoeff")
+        input_coeffs_copy = pymbolic.var("input_coeffs_copy")
+
+        dim = tgt_expansion.dim
+        v = [pymbolic.var(f"x{i}") for i in range(dim)]
+
+        wrangler = src_expansion.expansion_terms_wrangler
+        _, axis_permutation = wrangler._get_mi_ordering_key_and_axis_permutation()
+        slowest_idx = axis_permutation[0]
+        # max_mi[slowest_idx] = 2*(c - 1)
+        c = max_mi[slowest_idx] // 2 + 1
+        noutput_coeffs = c * (2*order + 1) ** (dim - 1)
 
         domains = [
             "{[output_icoeff]: 0<=output_icoeff<noutput_coeffs}",
+            "{[input_icoeff]: 0<=input_icoeff<ninput_coeffs}",
         ]
+
         insns = [
             lp.Assignment(
-                assignee=input_icoeff,
-                expression=srcidx_sym[output_icoeff],
-                id="input_icoeff",
+                assignee=input_coeffs_copy[input_icoeff],
+                expression=input_coeffs[input_icoeff],
+                id="input_copy",
+                temp_var_type=lp.Optional(None),
             ),
+        ]
+
+        idx = output_icoeff
+        for i in range(dim - 1, -1, -1):
+            new_idx = idx % (max_mi[i] + 1) if i > 0 else idx
+            insns.append(lp.Assignment(
+                    assignee=v[i],
+                    expression=new_idx,
+                    id=f"set_x{i}",
+                    temp_var_type=lp.Optional(None),
+            ))
+            idx = idx // (max_mi[i] + 1)
+
+        input_idx = wrangler.get_storage_index(v)
+        output_idx = 0
+        mult = 1
+        for i in range(dim - 1, -1, -1):
+            output_idx += mult*v[i]
+            mult *= (max_mi[i] + 1)
+
+        insns += [
             lp.Assignment(
                 assignee=output_coeffs[output_icoeff],
-                expression=pymbolic.primitives.If(
-                    pymbolic.primitives.Comparison(input_icoeff, ">=", 0),
-                    input_coeffs[input_icoeff],
-                    0,
-                ),
-                depends_on=frozenset(["input_icoeff"]),
+                expression=input_coeffs_copy[input_idx],
+                predicates=frozenset([
+                    pymbolic.primitives.Comparison(sum(v), "<=", order),
+                    pymbolic.primitives.Comparison(v[slowest_idx], "<", c),
+                ]),
+                depends_on=frozenset([f"set_x{i}" for i in range(dim)]
+                    + ["input_copy"]),
             )
         ]
 
-        srcidx = np.full(ncoeff_preprocessed, -1, dtype=np.int32)
-        for icoeff_src, term in enumerate(
-                src_expansion.get_coefficient_identifiers()):
-            new_icoeff_src = circulant_matrix_ident_to_index[term]
-            srcidx[new_icoeff_src] = icoeff_src
-
-        return lp.make_function(domains, insns,
+        knl = lp.make_function(domains, insns,
             kernel_data=[
                 lp.ValueArg("src_rscale", None),
                 lp.GlobalArg("output_coeffs", None, shape=ncoeff_preprocessed,
                     is_input=False, is_output=True),
                 lp.GlobalArg("input_coeffs", None, shape=ncoeff_src),
-                lp.TemporaryVariable(input_icoeff.name, dtype=np.int32),
-                lp.TemporaryVariable(
-                    srcidx_sym.name, initializer=srcidx,
-                    address_space=lp.AddressSpace.GLOBAL, read_only=True),
                 ...],
             name="m2l_preprocess_inner",
             lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
-            fixed_parameters={"noutput_coeffs": ncoeff_preprocessed},
+            fixed_parameters={"noutput_coeffs": noutput_coeffs,
+                              "ninput_coeffs": ncoeff_src},
         )
 
+        optimizations = [
+            lambda knl: lp.split_iname(knl, "m2l__input_icoeff",
+                32, inner_tag="l.0"),
+            lambda knl: lp.split_iname(knl, "m2l__output_icoeff",
+                32, inner_tag="l.0"),
+        ]
+
+        return (knl, optimizations)
+
     def postprocess_local_exprs(self, tgt_expansion, src_expansion, m2l_result,
             src_rscale, tgt_rscale, sac):
         circulant_matrix_mis, _, _ = \
@@ -607,7 +641,12 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase):
             "{[output_icoeff]: 0<=output_icoeff<ncoeff_tgt}"
         ]
 
-        return lp.make_function(domains, insns,
+        optimizations = [
+            lambda knl: lp.split_iname(knl, "m2l__output_icoeff",
+                32, inner_tag="l.0")
+        ]
+
+        return (lp.make_function(domains, insns,
             kernel_data=[
                 lp.ValueArg("src_rscale", None),
                 lp.ValueArg("tgt_rscale", None),
@@ -630,7 +669,7 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase):
             name="m2l_postprocess_inner",
             lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
             fixed_parameters=fixed_parameters,
-        )
+        ), optimizations)
 
 # }}} VolumeTaylorM2LTranslation
 
diff --git a/test/test_misc.py b/test/test_misc.py
index 5032e2fc..9b725781 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -49,6 +49,9 @@ from sumpy.expansion.diff_op import (
     make_identity_diff_op, concat, as_scalar_pde, diff,
     gradient, divergence, laplacian, curl)
 
+from sumpy.expansion import (FullExpansionTermsWrangler,
+    LinearPDEBasedExpansionTermsWrangler)
+
 import logging
 logger = logging.getLogger(__name__)
 
@@ -516,6 +519,45 @@ def test_weird_kernel(pde):
 # }}}
 
 
+# {{{ test_get_storage_index
+
+class TestKernel(ExpressionKernel):
+    def __init__(self, dim, max_mi):
+        super().__init__(dim=dim, expression=1, global_scaling_const=1,
+                is_complex_valued=False)
+        self._max_mi = max_mi
+
+    def get_pde_as_diff_op(self):
+        w = make_identity_diff_op(self.dim)
+        pde = diff(w, tuple(self._max_mi))
+        return pde
+
+
+@pytest.mark.parametrize("order", [6])
+@pytest.mark.parametrize("knl", [
+    LaplaceKernel(2),
+    LaplaceKernel(3),
+    TestKernel(2, (3, 0)),
+    TestKernel(2, (0, 3)),
+    TestKernel(3, (3, 0, 0)),
+    TestKernel(3, (0, 3, 0)),
+    TestKernel(3, (0, 0, 3)),
+    BiharmonicKernel(2),
+    BiharmonicKernel(3),
+])
+@pytest.mark.parametrize("compressed", (True, False))
+def test_get_storage_index(order, knl, compressed):
+    dim = knl.dim
+    if compressed:
+        wrangler = LinearPDEBasedExpansionTermsWrangler(order, dim, knl=knl)
+    else:
+        wrangler = FullExpansionTermsWrangler(order, dim)
+    for i, mi in enumerate(wrangler.get_coefficient_identifiers()):
+        assert i == wrangler.get_storage_index(mi)
+
+# }}}
+
+
 # You can test individual routines by typing
 # $ python test_misc.py 'test_pde_check_kernels(_acf,
 #       KernelInfo(HelmholtzKernel(2), k=5), order=5)'
-- 
GitLab