From d0e02f685ef6ef9297c4a19aec5d72ca1793467a Mon Sep 17 00:00:00 2001
From: Isuru Fernando <idf2@illinois.edu>
Date: Sun, 9 May 2021 14:40:41 -0500
Subject: [PATCH] Move get_pde_as_diff_op to kernel and simplify code

---
 sumpy/expansion/__init__.py  | 128 +++++++----------------------------
 sumpy/expansion/local.py     |  34 +++-------
 sumpy/expansion/multipole.py |  34 +++-------
 sumpy/kernel.py              |  42 ++++++++++++
 4 files changed, 86 insertions(+), 152 deletions(-)

diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py
index 79bc9cb8..df39c766 100644
--- a/sumpy/expansion/__init__.py
+++ b/sumpy/expansion/__init__.py
@@ -24,7 +24,6 @@ import logging
 from pytools import memoize_method
 import sumpy.symbolic as sym
 from sumpy.tools import add_mi
-from .diff_op import make_identity_diff_op, laplacian
 from typing import List, Tuple
 
 __doc__ = """
@@ -275,7 +274,7 @@ class ExpansionTermsWrangler:
         # to which it is orthogonal to and the constant `c` described above
         hyperplanes = []
         if isinstance(self, LinearPDEBasedExpansionTermsWrangler):
-            pde_dict, = self.get_pde_as_diff_op().eqs
+            pde_dict, = self.knl.get_pde_as_diff_op().eqs
 
             if not all(ident.mi in mi_to_index for ident in pde_dict):
                 # The order of the expansion is less than the order of the PDE.
@@ -406,18 +405,18 @@ class CSEMatVecOperator:
 class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
     """
     .. automethod:: __init__
-    .. automethod:: get_pde_as_diff_op
     """
 
     init_arg_names = ("order", "dim", "max_mi")
 
-    def __init__(self, order, dim, max_mi=None):
+    def __init__(self, order, dim, knl, max_mi=None):
         r"""
         :param order: order of the expansion
         :param dim: number of dimensions
+        :param knl: kernel for the PDE
         """
-        super().__init__(order, dim,
-                max_mi)
+        super().__init__(order, dim, max_mi)
+        self.knl = knl
 
     def get_coefficient_identifiers(self):
         return self.stored_identifiers
@@ -443,14 +442,6 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
         stored_identifiers, _ = self.get_stored_ids_and_unscaled_projection_matrix()
         return stored_identifiers
 
-    def get_pde_as_diff_op(self):
-        r"""
-        Returns the PDE as a :class:`sumpy.expansion.diff_op.LinearPDESystemOperator`
-        object `L` where `L(u) = 0` is the PDE.
-        """
-
-        raise NotImplementedError
-
     @memoize_method
     def get_stored_ids_and_unscaled_projection_matrix(self):
         from pytools import ProcessLogger
@@ -460,7 +451,7 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
         coeff_ident_enumerate_dict = {tuple(mi): i for
                                             (i, mi) in enumerate(mis)}
 
-        diff_op = self.get_pde_as_diff_op()
+        diff_op = self.knl.get_pde_as_diff_op()
         assert len(diff_op.eqs) == 1
         pde_dict = {k.mi: v for k, v in diff_op.eqs[0].items()}
         for ident in pde_dict.keys():
@@ -582,46 +573,6 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
                                  projection_with_rscale, shape)
 
 
-class LaplaceExpansionTermsWrangler(LinearPDEBasedExpansionTermsWrangler):
-
-    init_arg_names = ("order", "dim", "max_mi")
-
-    def __init__(self, order, dim, max_mi=None):
-        super().__init__(order=order, dim=dim,
-            max_mi=max_mi)
-
-    def get_pde_as_diff_op(self):
-        w = make_identity_diff_op(self.dim)
-        return laplacian(w)
-
-
-class HelmholtzExpansionTermsWrangler(LinearPDEBasedExpansionTermsWrangler):
-
-    init_arg_names = ("order", "dim", "helmholtz_k_name", "max_mi")
-
-    def __init__(self, order, dim, helmholtz_k_name, max_mi=None):
-        self.helmholtz_k_name = helmholtz_k_name
-        super().__init__(order=order, dim=dim,
-            max_mi=max_mi)
-
-    def get_pde_as_diff_op(self, **kwargs):
-        w = make_identity_diff_op(self.dim)
-        k = sym.Symbol(self.helmholtz_k_name)
-        return (laplacian(w) + k**2 * w)
-
-
-class BiharmonicExpansionTermsWrangler(LinearPDEBasedExpansionTermsWrangler):
-
-    init_arg_names = ("order", "dim", "max_mi")
-
-    def __init__(self, order, dim, max_mi=None):
-        super().__init__(order=order, dim=dim,
-            max_mi=max_mi)
-
-    def get_pde_as_diff_op(self, **kwargs):
-        w = make_identity_diff_op(self.dim)
-        return laplacian(laplacian(w))
-
 # }}}
 
 
@@ -678,36 +629,19 @@ class VolumeTaylorExpansion(VolumeTaylorExpansionBase):
         self.expansion_terms_wrangler_key = (order, kernel.dim)
 
 
-class LaplaceConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase):
-
-    expansion_terms_wrangler_class = LaplaceExpansionTermsWrangler
-    expansion_terms_wrangler_cache = {}
-
-    # not user-facing, be strict about having to pass use_rscale
-    def __init__(self, kernel, order, use_rscale):
-        self.expansion_terms_wrangler_key = (order, kernel.dim)
-
-
-class HelmholtzConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase):
+class LinearPDEConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase):
 
-    expansion_terms_wrangler_class = HelmholtzExpansionTermsWrangler
+    expansion_terms_wrangler_class = LinearPDEBasedExpansionTermsWrangler
     expansion_terms_wrangler_cache = {}
 
     # not user-facing, be strict about having to pass use_rscale
     def __init__(self, kernel, order, use_rscale):
-        helmholtz_k_name = kernel.get_base_kernel().helmholtz_k_name
-        self.expansion_terms_wrangler_key = (order, kernel.dim, helmholtz_k_name)
-
-
-class BiharmonicConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase):
-
-    expansion_terms_wrangler_class = BiharmonicExpansionTermsWrangler
-    expansion_terms_wrangler_cache = {}
+        self.expansion_terms_wrangler_key = (order, kernel.dim, kernel)
 
-    # not user-facing, be strict about having to pass use_rscale
-    def __init__(self, kernel, order, use_rscale):
-        self.expansion_terms_wrangler_key = (order, kernel.dim)
 
+LaplaceConformingVolumeTaylorExpansion = LinearPDEConformingVolumeTaylorExpansion
+HelmholtzConformingVolumeTaylorExpansion = LinearPDEConformingVolumeTaylorExpansion
+BiharmonicConformingVolumeTaylorExpansion = LinearPDEConformingVolumeTaylorExpansion
 
 # }}}
 
@@ -757,13 +691,10 @@ class DefaultExpansionFactory(ExpansionFactoryBase):
     def get_local_expansion_class(self, base_kernel):
         """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
-        from sumpy.kernel import (HelmholtzKernel, LaplaceKernel, YukawaKernel,
-                BiharmonicKernel, StokesletKernel, StressletKernel)
+        from sumpy.kernel import (HelmholtzKernel, YukawaKernel)
 
         from sumpy.expansion.local import (H2DLocalExpansion, Y2DLocalExpansion,
-                HelmholtzConformingVolumeTaylorLocalExpansion,
-                LaplaceConformingVolumeTaylorLocalExpansion,
-                BiharmonicConformingVolumeTaylorLocalExpansion,
+                LinearPDEConformingVolumeTaylorLocalExpansion,
                 VolumeTaylorLocalExpansion)
 
         if (isinstance(base_kernel.get_base_kernel(), HelmholtzKernel)
@@ -772,27 +703,20 @@ class DefaultExpansionFactory(ExpansionFactoryBase):
         elif (isinstance(base_kernel.get_base_kernel(), YukawaKernel)
                 and base_kernel.dim == 2):
             return Y2DLocalExpansion
-        elif isinstance(base_kernel.get_base_kernel(), HelmholtzKernel):
-            return HelmholtzConformingVolumeTaylorLocalExpansion
-        elif isinstance(base_kernel.get_base_kernel(), LaplaceKernel):
-            return LaplaceConformingVolumeTaylorLocalExpansion
-        elif isinstance(base_kernel.get_base_kernel(),
-                (BiharmonicKernel, StokesletKernel, StressletKernel)):
-            return BiharmonicConformingVolumeTaylorLocalExpansion
-        else:
+        try:
+            base_kernel.get_base_kernel().get_pde_as_diff_op()
+            return LinearPDEConformingVolumeTaylorLocalExpansion
+        except NotImplementedError:
             return VolumeTaylorLocalExpansion
 
     def get_multipole_expansion_class(self, base_kernel):
         """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
-        from sumpy.kernel import (HelmholtzKernel, LaplaceKernel, YukawaKernel,
-                BiharmonicKernel, StokesletKernel, StressletKernel)
+        from sumpy.kernel import (HelmholtzKernel, YukawaKernel)
 
         from sumpy.expansion.multipole import (H2DMultipoleExpansion,
                 Y2DMultipoleExpansion,
-                LaplaceConformingVolumeTaylorMultipoleExpansion,
-                HelmholtzConformingVolumeTaylorMultipoleExpansion,
-                BiharmonicConformingVolumeTaylorMultipoleExpansion,
+                LinearPDEConformingVolumeTaylorMultipoleExpansion,
                 VolumeTaylorMultipoleExpansion)
 
         if (isinstance(base_kernel.get_base_kernel(), HelmholtzKernel)
@@ -801,14 +725,10 @@ class DefaultExpansionFactory(ExpansionFactoryBase):
         elif (isinstance(base_kernel.get_base_kernel(), YukawaKernel)
                 and base_kernel.dim == 2):
             return Y2DMultipoleExpansion
-        elif isinstance(base_kernel.get_base_kernel(), LaplaceKernel):
-            return LaplaceConformingVolumeTaylorMultipoleExpansion
-        elif isinstance(base_kernel.get_base_kernel(), HelmholtzKernel):
-            return HelmholtzConformingVolumeTaylorMultipoleExpansion
-        elif isinstance(base_kernel.get_base_kernel(),
-                (BiharmonicKernel, StokesletKernel, StressletKernel)):
-            return BiharmonicConformingVolumeTaylorMultipoleExpansion
-        else:
+        try:
+            base_kernel.get_base_kernel().get_pde_as_diff_op()
+            return LinearPDEConformingVolumeTaylorMultipoleExpansion
+        except NotImplementedError:
             return VolumeTaylorMultipoleExpansion
 
 # }}}
diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py
index e6e16d2a..59c945ab 100644
--- a/sumpy/expansion/local.py
+++ b/sumpy/expansion/local.py
@@ -24,9 +24,7 @@ import sumpy.symbolic as sym
 from sumpy.tools import add_to_sac
 
 from sumpy.expansion import (
-    ExpansionBase, VolumeTaylorExpansion, LaplaceConformingVolumeTaylorExpansion,
-    HelmholtzConformingVolumeTaylorExpansion,
-    BiharmonicConformingVolumeTaylorExpansion)
+    ExpansionBase, VolumeTaylorExpansion, LinearPDEConformingVolumeTaylorExpansion)
 
 from sumpy.tools import mi_increment_axis
 from pytools import single_valued
@@ -405,34 +403,22 @@ class VolumeTaylorLocalExpansion(
         VolumeTaylorExpansion.__init__(self, kernel, order, use_rscale)
 
 
-class LaplaceConformingVolumeTaylorLocalExpansion(
-        LaplaceConformingVolumeTaylorExpansion,
+class LinearPDEConformingVolumeTaylorLocalExpansion(
+        LinearPDEConformingVolumeTaylorExpansion,
         VolumeTaylorLocalExpansionBase):
 
     def __init__(self, kernel, order, use_rscale=None):
         VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale)
-        LaplaceConformingVolumeTaylorExpansion.__init__(
+        LinearPDEConformingVolumeTaylorExpansion.__init__(
                 self, kernel, order, use_rscale)
 
 
-class HelmholtzConformingVolumeTaylorLocalExpansion(
-        HelmholtzConformingVolumeTaylorExpansion,
-        VolumeTaylorLocalExpansionBase):
-
-    def __init__(self, kernel, order, use_rscale=None):
-        VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale)
-        HelmholtzConformingVolumeTaylorExpansion.__init__(
-                self, kernel, order, use_rscale)
-
-
-class BiharmonicConformingVolumeTaylorLocalExpansion(
-        BiharmonicConformingVolumeTaylorExpansion,
-        VolumeTaylorLocalExpansionBase):
-
-    def __init__(self, kernel, order, use_rscale=None):
-        VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale)
-        BiharmonicConformingVolumeTaylorExpansion.__init__(
-                self, kernel, order, use_rscale)
+LaplaceConformingVolumeTaylorLocalExpansion = \
+        LinearPDEConformingVolumeTaylorLocalExpansion
+HelmholtzConformingVolumeTaylorLocalExpansion = \
+        LinearPDEConformingVolumeTaylorLocalExpansion
+BiharmonicConformingVolumeTaylorLocalExpansion = \
+        LinearPDEConformingVolumeTaylorLocalExpansion
 
 # }}}
 
diff --git a/sumpy/expansion/multipole.py b/sumpy/expansion/multipole.py
index b9b1a077..957aec90 100644
--- a/sumpy/expansion/multipole.py
+++ b/sumpy/expansion/multipole.py
@@ -23,9 +23,7 @@ THE SOFTWARE.
 import sumpy.symbolic as sym  # noqa
 
 from sumpy.expansion import (
-    ExpansionBase, VolumeTaylorExpansion, LaplaceConformingVolumeTaylorExpansion,
-    HelmholtzConformingVolumeTaylorExpansion,
-    BiharmonicConformingVolumeTaylorExpansion)
+    ExpansionBase, VolumeTaylorExpansion, LinearPDEConformingVolumeTaylorExpansion)
 from pytools import factorial
 from sumpy.tools import mi_set_axis, add_to_sac
 
@@ -346,34 +344,22 @@ class VolumeTaylorMultipoleExpansion(
         VolumeTaylorExpansion.__init__(self, kernel, order, use_rscale)
 
 
-class LaplaceConformingVolumeTaylorMultipoleExpansion(
-        LaplaceConformingVolumeTaylorExpansion,
+class LinearPDEConformingVolumeTaylorMultipoleExpansion(
+        LinearPDEConformingVolumeTaylorExpansion,
         VolumeTaylorMultipoleExpansionBase):
 
     def __init__(self, kernel, order, use_rscale=None):
         VolumeTaylorMultipoleExpansionBase.__init__(self, kernel, order, use_rscale)
-        LaplaceConformingVolumeTaylorExpansion.__init__(
+        LinearPDEConformingVolumeTaylorExpansion.__init__(
                 self, kernel, order, use_rscale)
 
 
-class HelmholtzConformingVolumeTaylorMultipoleExpansion(
-        HelmholtzConformingVolumeTaylorExpansion,
-        VolumeTaylorMultipoleExpansionBase):
-
-    def __init__(self, kernel, order, use_rscale=None):
-        VolumeTaylorMultipoleExpansionBase.__init__(self, kernel, order, use_rscale)
-        HelmholtzConformingVolumeTaylorExpansion.__init__(
-                self, kernel, order, use_rscale)
-
-
-class BiharmonicConformingVolumeTaylorMultipoleExpansion(
-        BiharmonicConformingVolumeTaylorExpansion,
-        VolumeTaylorMultipoleExpansionBase):
-
-    def __init__(self, kernel, order, use_rscale=None):
-        VolumeTaylorMultipoleExpansionBase.__init__(self, kernel, order, use_rscale)
-        BiharmonicConformingVolumeTaylorExpansion.__init__(
-                self, kernel, order, use_rscale)
+LaplaceConformingVolumeTaylorMultipoleExpansion = \
+        LinearPDEConformingVolumeTaylorMultipoleExpansion
+HelmholtzConformingVolumeTaylorMultipoleExpansion = \
+        LinearPDEConformingVolumeTaylorMultipoleExpansion
+BiharmonicConformingVolumeTaylorMultipoleExpansion = \
+        LinearPDEConformingVolumeTaylorMultipoleExpansion
 
 # }}}
 
diff --git a/sumpy/kernel.py b/sumpy/kernel.py
index 7d123360..6f16b2cf 100644
--- a/sumpy/kernel.py
+++ b/sumpy/kernel.py
@@ -25,6 +25,7 @@ import loopy as lp
 import numpy as np
 from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin
 from sumpy.symbolic import pymbolic_real_norm_2
+import sumpy.symbolic as sym
 from pymbolic.primitives import make_sym_vector
 from pymbolic import var
 from collections import defaultdict
@@ -375,6 +376,14 @@ class ExpressionKernel(Kernel):
         from sumpy.tools import ExprDerivativeTaker
         return ExprDerivativeTaker(self.get_expression(dvec), dvec, rscale, sac)
 
+    def get_pde_as_diff_op(self):
+        r"""
+        Returns the PDE for the kernel as a
+        :class:`sumpy.expansion.diff_op.LinearPDESystemOperator` object `L`
+        where `L(u) = 0` is the PDE.
+        """
+        raise NotImplementedError
+
 
 one_kernel_2d = ExpressionKernel(
         dim=2,
@@ -427,6 +436,11 @@ class LaplaceKernel(ExpressionKernel):
         from sumpy.tools import LaplaceDerivativeTaker
         return LaplaceDerivativeTaker(self.get_expression(dvec), dvec, rscale, sac)
 
+    def get_pde_as_diff_op(self):
+        from sumpy.expansion.diff_op import make_identity_diff_op, laplacian
+        w = make_identity_diff_op(self.dim)
+        return laplacian(w)
+
 
 class BiharmonicKernel(ExpressionKernel):
     init_arg_names = ("dim",)
@@ -471,6 +485,11 @@ class BiharmonicKernel(ExpressionKernel):
         return RadialDerivativeTaker(self.get_expression(dvec), dvec, rscale,
                 sac)
 
+    def get_pde_as_diff_op(self):
+        from sumpy.expansion.diff_op import make_identity_diff_op, laplacian
+        w = make_identity_diff_op(self.dim)
+        return laplacian(laplacian(w))
+
 
 class HelmholtzKernel(ExpressionKernel):
     init_arg_names = ("dim", "helmholtz_k_name", "allow_evanescent")
@@ -548,6 +567,13 @@ class HelmholtzKernel(ExpressionKernel):
         from sumpy.tools import HelmholtzDerivativeTaker
         return HelmholtzDerivativeTaker(self.get_expression(dvec), dvec, rscale, sac)
 
+    def get_pde_as_diff_op(self):
+        from sumpy.expansion.diff_op import make_identity_diff_op, laplacian
+
+        w = make_identity_diff_op(self.dim)
+        k = sym.Symbol(self.helmholtz_k_name)
+        return (laplacian(w) + k**2 * w)
+
 
 class YukawaKernel(ExpressionKernel):
     init_arg_names = ("dim", "yukawa_lambda_name")
@@ -629,6 +655,12 @@ class YukawaKernel(ExpressionKernel):
         from sumpy.tools import HelmholtzDerivativeTaker
         return HelmholtzDerivativeTaker(self.get_expression(dvec), dvec, rscale, sac)
 
+    def get_pde_as_diff_op(self):
+        from sumpy.expansion.diff_op import make_identity_diff_op, laplacian
+        w = make_identity_diff_op(self.dim)
+        lam = sym.Symbol(self.yukawa_lambda_name)
+        return (laplacian(w) - lam**2 * w)
+
 
 class StokesletKernel(ExpressionKernel):
     init_arg_names = ("dim", "icomp", "jcomp", "viscosity_mu_name")
@@ -696,6 +728,11 @@ class StokesletKernel(ExpressionKernel):
 
     mapper_method = "map_stokeslet_kernel"
 
+    def get_pde_as_diff_op(self):
+        from sumpy.expansion.diff_op import make_identity_diff_op, laplacian
+        w = make_identity_diff_op(self.dim)
+        return laplacian(laplacian(w))
+
 
 class StressletKernel(ExpressionKernel):
     init_arg_names = ("dim", "icomp", "jcomp", "kcomp", "viscosity_mu_name")
@@ -765,6 +802,11 @@ class StressletKernel(ExpressionKernel):
 
     mapper_method = "map_stresslet_kernel"
 
+    def get_pde_as_diff_op(self):
+        from sumpy.expansion.diff_op import make_identity_diff_op, laplacian
+        w = make_identity_diff_op(self.dim)
+        return laplacian(laplacian(w))
+
 # }}}
 
 
-- 
GitLab