From 6e8e2e7c90c794e84d57304d58ac977e2f5e0528 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 19 Apr 2022 00:56:09 -0500
Subject: [PATCH] Fix, test order varying with level

---
 sumpy/e2e.py     | 21 ++++++-------------
 test/test_fmm.py | 52 ++++++++++++++++++++++++++++++------------------
 2 files changed, 39 insertions(+), 34 deletions(-)

diff --git a/sumpy/e2e.py b/sumpy/e2e.py
index eba1e581..6f742f72 100644
--- a/sumpy/e2e.py
+++ b/sumpy/e2e.py
@@ -862,11 +862,6 @@ class E2EFromChildren(E2EBase):
     default_name = "e2e_from_children"
 
     def get_kernel(self):
-        if self.src_expansion is not self.tgt_expansion:
-            raise RuntimeError(
-                f"{type(self).__name__} requires that the source "
-                "and target expansion are the same object")
-
         ncoeffs_src = len(self.src_expansion)
         ncoeffs_tgt = len(self.tgt_expansion)
 
@@ -984,12 +979,8 @@ class E2EFromParent(E2EBase):
     default_name = "e2e_from_parent"
 
     def get_kernel(self):
-        if self.src_expansion is not self.tgt_expansion:
-            raise RuntimeError(
-                f"{self.default_name} requires that the source "
-                "and target expansion are the same object")
-
-        ncoeffs = len(self.src_expansion)
+        ncoeffs_src = len(self.src_expansion)
+        ncoeffs_tgt = len(self.tgt_expansion)
 
         # To clarify terminology:
         #
@@ -1020,14 +1011,14 @@ class E2EFromParent(E2EBase):
                     <> src_coeff{i} = \
                         src_expansions[src_ibox - src_base_ibox, {i}] \
                         {{id_prefix=read_expn,dep=read_src_ibox}}
-                    """.format(i=i) for i in range(ncoeffs)] + [
+                    """.format(i=i) for i in range(ncoeffs_src)] + [
 
                     ] + self.get_translation_loopy_insns() + ["""
 
                     tgt_expansions[tgt_ibox - tgt_base_ibox, {i}] = \
                         tgt_expansions[tgt_ibox - tgt_base_ibox, {i}] + coeff{i} \
                         {{id_prefix=write_expn,nosync=read_expn*}}
-                    """.format(i=i) for i in range(ncoeffs)] + ["""
+                    """.format(i=i) for i in range(ncoeffs_tgt)] + ["""
                 end
                 """],
                 [
@@ -1040,9 +1031,9 @@ class E2EFromParent(E2EBase):
                     lp.ValueArg("ntgt_level_boxes,nsrc_level_boxes", np.int32),
                     lp.GlobalArg("box_parent_ids", None, shape="nboxes"),
                     lp.GlobalArg("tgt_expansions", None,
-                        shape=("ntgt_level_boxes", ncoeffs), offset=lp.auto),
+                        shape=("ntgt_level_boxes", ncoeffs_tgt), offset=lp.auto),
                     lp.GlobalArg("src_expansions", None,
-                        shape=("nsrc_level_boxes", ncoeffs), offset=lp.auto),
+                        shape=("nsrc_level_boxes", ncoeffs_src), offset=lp.auto),
                     "..."
                 ] + gather_loopy_arguments([self.src_expansion, self.tgt_expansion]),
                 name=self.name, assumptions="ntgt_boxes>=1",
diff --git a/test/test_fmm.py b/test/test_fmm.py
index 5e505c4b..3c5ec6d7 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -59,25 +59,32 @@ else:
 
 @pytest.mark.parametrize("optimized_m2l, use_fft",
     [(False, False), (True, False), (True, True)])
-@pytest.mark.parametrize("knl, local_expn_class, mpole_expn_class",
-[
-    (LaplaceKernel(2), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion),
-    (LaplaceKernel(2), LinearPDEConformingVolumeTaylorLocalExpansion,
-        LinearPDEConformingVolumeTaylorMultipoleExpansion),
-    (LaplaceKernel(3), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion),
-    (LaplaceKernel(3), LinearPDEConformingVolumeTaylorLocalExpansion,
-        LinearPDEConformingVolumeTaylorMultipoleExpansion),
-    (HelmholtzKernel(2), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion),
-    (HelmholtzKernel(2), LinearPDEConformingVolumeTaylorLocalExpansion,
-        LinearPDEConformingVolumeTaylorMultipoleExpansion),
-    (HelmholtzKernel(2), H2DLocalExpansion, H2DMultipoleExpansion),
-    (HelmholtzKernel(3), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion),
-    (HelmholtzKernel(3), LinearPDEConformingVolumeTaylorLocalExpansion,
-        LinearPDEConformingVolumeTaylorMultipoleExpansion),
-    (YukawaKernel(2), Y2DLocalExpansion, Y2DMultipoleExpansion),
-])
+@pytest.mark.parametrize(
+        ("knl", "local_expn_class", "mpole_expn_class",
+        "order_varies_with_level"), [
+            (LaplaceKernel(2), VolumeTaylorLocalExpansion,
+                VolumeTaylorMultipoleExpansion, False),
+            (LaplaceKernel(2), LinearPDEConformingVolumeTaylorLocalExpansion,
+                LinearPDEConformingVolumeTaylorMultipoleExpansion, False),
+            (LaplaceKernel(3), VolumeTaylorLocalExpansion,
+                VolumeTaylorMultipoleExpansion, False),
+            (LaplaceKernel(3), LinearPDEConformingVolumeTaylorLocalExpansion,
+                LinearPDEConformingVolumeTaylorMultipoleExpansion, False),
+            (HelmholtzKernel(2), VolumeTaylorLocalExpansion,
+                VolumeTaylorMultipoleExpansion, False),
+            (HelmholtzKernel(2), LinearPDEConformingVolumeTaylorLocalExpansion,
+                LinearPDEConformingVolumeTaylorMultipoleExpansion, False),
+            (HelmholtzKernel(2), H2DLocalExpansion, H2DMultipoleExpansion, False),
+            (HelmholtzKernel(2), H2DLocalExpansion, H2DMultipoleExpansion, True),
+            (HelmholtzKernel(3), VolumeTaylorLocalExpansion,
+                VolumeTaylorMultipoleExpansion, False),
+            (HelmholtzKernel(3), LinearPDEConformingVolumeTaylorLocalExpansion,
+                LinearPDEConformingVolumeTaylorMultipoleExpansion, False),
+            (YukawaKernel(2), Y2DLocalExpansion, Y2DMultipoleExpansion,
+                False),
+            ])
 def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
-        optimized_m2l, use_fft):
+        order_varies_with_level, optimized_m2l, use_fft):
     logging.basicConfig(level=logging.INFO)
 
     if local_expn_class == VolumeTaylorLocalExpansion and use_fft:
@@ -192,12 +199,19 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
                 partial(local_expn_class, knl),
                 target_kernels, use_fft_for_m2l=use_fft)
 
+        if order_varies_with_level:
+            def fmm_level_to_order(kernel, kernel_args, tree, lev):
+                return order + lev % 2
+        else:
+            def fmm_level_to_order(kernel, kernel_args, tree, lev):
+                return order
+
         with warnings.catch_warnings():
             if not optimized_m2l:
                 warnings.simplefilter("ignore",
                     SumpyTranslationClassesDataNotSuppliedWarning)
             wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
-                fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order,
+                fmm_level_to_order=fmm_level_to_order,
                 kernel_extra_kwargs=extra_kwargs,
                 translation_classes_data=translation_classes_data)
 
-- 
GitLab