From 6a961e7e6a3cfbc8bf06702afc131a79e6c71d51 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sun, 18 Sep 2022 11:40:19 +0300
Subject: [PATCH] fixup inheritence and pylint in expansions

---
 sumpy/expansion/__init__.py  | 208 ++++++++++++++++++++++-------------
 sumpy/expansion/local.py     |  43 +++++---
 sumpy/expansion/m2l.py       |  57 +++++-----
 sumpy/expansion/multipole.py |  13 ++-
 4 files changed, 196 insertions(+), 125 deletions(-)

diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py
index 99ed13b3..ca974b93 100644
--- a/sumpy/expansion/__init__.py
+++ b/sumpy/expansion/__init__.py
@@ -20,14 +20,26 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-import logging
+from abc import ABC, abstractmethod
+from typing import Any, ClassVar, Dict, Hashable, List, Tuple, Type
+
 from pytools import memoize_method
+
 import sumpy.symbolic as sym
 from sumpy.tools import add_mi
-from typing import List, Tuple
+
+import logging
+logger = logging.getLogger(__name__)
+
 
 __doc__ = """
 .. autoclass:: ExpansionBase
+
+Expansion Wranglers
+^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: ExpansionTermsWrangler
+.. autoclass:: FullExpansionTermsWrangler
 .. autoclass:: LinearPDEBasedExpansionTermsWrangler
 
 Expansion Factories
@@ -38,18 +50,24 @@ Expansion Factories
 .. autoclass:: VolumeTaylorExpansionFactory
 """
 
-logger = logging.getLogger(__name__)
-
 
-# {{{ base class
+# {{{ expansion base
 
-class ExpansionBase:
+class ExpansionBase(ABC):
     """
-    .. automethod:: with_kernel
-    .. automethod:: __len__
+    .. attribute:: kernel
+    .. attribute:: order
+    .. attribute:: use_rscale
+
     .. automethod:: get_coefficient_identifiers
     .. automethod:: coefficients_from_source
-    .. automethod:: translate_from
+    .. automethod:: coefficients_from_source_vec
+    .. automethod:: evaluate
+
+    .. automethod:: with_kernel
+    .. automethod:: copy
+
+    .. automethod:: __len__
     .. automethod:: __eq__
     .. automethod:: __ne__
     """
@@ -95,18 +113,15 @@ class ExpansionBase:
 
     # }}}
 
-    def with_kernel(self, kernel):
-        return type(self)(kernel, self.order, self.use_rscale)
-
-    def __len__(self):
-        return len(self.get_coefficient_identifiers())
+    # {{{ abstract interface
 
+    @abstractmethod
     def get_coefficient_identifiers(self):
         """
-        Returns the identifiers of the coefficients that actually get stored.
+        :returns: the identifiers of the coefficients that actually get stored.
         """
-        raise NotImplementedError
 
+    @abstractmethod
     def coefficients_from_source(self, kernel, avec, bvec, rscale, sac=None):
         """Form an expansion from a source point.
 
@@ -119,10 +134,9 @@ class ExpansionBase:
         :returns: a list of :mod:`sympy` expressions representing
             the coefficients of the expansion.
         """
-        raise NotImplementedError
 
-    def coefficients_from_source_vec(self, kernels, avec, bvec, rscale, weights,
-            sac=None):
+    def coefficients_from_source_vec(self,
+            kernels, avec, bvec, rscale, weights, sac=None):
         """Form an expansion with a linear combination of kernels and weights.
 
         :arg avec: vector from source to center.
@@ -141,34 +155,20 @@ class ExpansionBase:
                 result[i] += weight * coeffs[i]
         return result
 
+    @abstractmethod
     def evaluate(self, kernel, coeffs, bvec, rscale, sac=None):
         """
-        :return: a :mod:`sympy` expression corresponding
+        :returns: a :mod:`sympy` expression corresponding
             to the evaluated expansion with the coefficients
             in *coeffs*.
         """
 
-        raise NotImplementedError
-
-    def translate_from(self, src_expansion, src_coeff_exprs, src_rscale,
-            dvec, tgt_rscale, sac=None):
-        raise NotImplementedError
+    # }}}
 
-    def update_persistent_hash(self, key_hash, key_builder):
-        key_hash.update(type(self).__name__.encode("utf8"))
-        key_builder.rec(key_hash, self.kernel)
-        key_builder.rec(key_hash, self.order)
-        key_builder.rec(key_hash, self.use_rscale)
+    # {{{ copy
 
-    def __eq__(self, other):
-        return (
-                type(self) == type(other)
-                and self.kernel == other.kernel
-                and self.order == other.order
-                and self.use_rscale == other.use_rscale)
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
+    def with_kernel(self, kernel):
+        return type(self)(kernel, self.order, self.use_rscale)
 
     def copy(self, **kwargs):
         new_kwargs = {
@@ -184,14 +184,44 @@ class ExpansionBase:
 
         return type(self)(**new_kwargs)
 
+    # }}}
+
+    def update_persistent_hash(self, key_hash, key_builder):
+        key_hash.update(type(self).__name__.encode("utf8"))
+        key_builder.rec(key_hash, self.kernel)
+        key_builder.rec(key_hash, self.order)
+        key_builder.rec(key_hash, self.use_rscale)
+
+    def __len__(self):
+        return len(self.get_coefficient_identifiers())
+
+    def __eq__(self, other):
+        return (
+                type(self) == type(other)
+                and self.kernel == other.kernel
+                and self.order == other.order
+                and self.use_rscale == other.use_rscale)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
 
 # }}}
 
 
 # {{{ expansion terms wrangler
 
-class ExpansionTermsWrangler:
+class ExpansionTermsWrangler(ABC):
+    """
+    .. attribute:: order
+    .. attribute:: dim
+    .. attribute:: max_mi
 
+    .. automethod:: get_coefficient_identifiers
+    .. automethod:: get_full_kernel_derivatives_from_stored
+    .. automethod:: get_stored_mpole_coefficients_from_full
+
+    .. automethod:: get_full_coefficient_identifiers
+    """
     init_arg_names = ("order", "dim", "max_mi")
 
     def __init__(self, order, dim, max_mi=None):
@@ -199,16 +229,23 @@ class ExpansionTermsWrangler:
         self.dim = dim
         self.max_mi = max_mi
 
+    # {{{ abstract interface
+
+    @abstractmethod
     def get_coefficient_identifiers(self):
-        raise NotImplementedError
+        pass
+
+    @abstractmethod
+    def get_full_kernel_derivatives_from_stored(self,
+            stored_kernel_derivatives, rscale, sac=None):
+        pass
 
-    def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives,
-            rscale, sac=None):
-        raise NotImplementedError
+    @abstractmethod
+    def get_stored_mpole_coefficients_from_full(self,
+            full_mpole_coefficients, rscale, sac=None):
+        pass
 
-    def get_stored_mpole_coefficients_from_full(self, full_mpole_coefficients,
-            rscale, sac=None):
-        raise NotImplementedError
+    # }}}
 
     @memoize_method
     def get_full_coefficient_identifiers(self):
@@ -241,6 +278,8 @@ class ExpansionTermsWrangler:
 
         return type(self)(**new_kwargs)
 
+    # {{{ hyperplane helpers
+
     def _get_mi_hyperpplanes(self) -> List[Tuple[int, int]]:
         r"""
         Coefficient storage is organized into "hyperplanes" in multi-index
@@ -303,18 +342,23 @@ class ExpansionTermsWrangler:
 
         return res
 
+    # }}}
+
 
 class FullExpansionTermsWrangler(ExpansionTermsWrangler):
+    def get_coefficient_identifiers(self):
+        return super().get_full_coefficient_identifiers()
 
-    get_coefficient_identifiers = (
-            ExpansionTermsWrangler.get_full_coefficient_identifiers)
-
-    def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives,
-            rscale, sac=None):
+    def get_full_kernel_derivatives_from_stored(self,
+            stored_kernel_derivatives, rscale, sac=None):
         return stored_kernel_derivatives
 
-    get_stored_mpole_coefficients_from_full = (
-            get_full_kernel_derivatives_from_stored)
+    def get_stored_mpole_coefficients_from_full(self,
+            full_mpole_coefficients, rscale, sac=None):
+        return self.get_full_kernel_derivatives_from_stored(
+            full_mpole_coefficients, rscale, sac=sac)
+
+# }}}
 
 
 # {{{ sparse matrix-vector multiplication
@@ -392,6 +436,8 @@ class CSEMatVecOperator:
 # }}}
 
 
+# {{{ LinearPDEBasedExpansionTermsWrangler
+
 class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
     """
     .. automethod:: __init__
@@ -411,18 +457,18 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
     def get_coefficient_identifiers(self):
         return self.stored_identifiers
 
-    def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives,
-            rscale, sac=None):
-
+    def get_full_kernel_derivatives_from_stored(self,
+            stored_kernel_derivatives, rscale, sac=None):
         from sumpy.tools import add_to_sac
+
         projection_matrix = self.get_projection_matrix(rscale)
         return projection_matrix.matvec(stored_kernel_derivatives,
                 lambda x: add_to_sac(sac, x))
 
-    def get_stored_mpole_coefficients_from_full(self, full_mpole_coefficients,
-            rscale, sac=None):
-
+    def get_stored_mpole_coefficients_from_full(self,
+            full_mpole_coefficients, rscale, sac=None):
         from sumpy.tools import add_to_sac
+
         projection_matrix = self.get_projection_matrix(rscale)
         return projection_matrix.transpose_matvec(full_mpole_coefficients,
                 lambda x: add_to_sac(sac, x))
@@ -636,9 +682,13 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
 # }}}
 
 
-# {{{ volume taylor
+# {{{ volume taylor expansion
+
+# FIXME: This is called an expansion but doesn't inherit from ExpansionBase?
 
-class VolumeTaylorExpansionBase:
+class VolumeTaylorExpansionMixin:
+    expansion_terms_wrangler_class: ClassVar[Type[ExpansionTermsWrangler]]
+    expansion_terms_wrangler_cache: ClassVar[Dict[Hashable, Any]] = {}
 
     @classmethod
     def get_or_make_expansion_terms_wrangler(cls, *key):
@@ -679,20 +729,16 @@ class VolumeTaylorExpansionBase:
         return self._storage_loc_dict[i]
 
 
-class VolumeTaylorExpansion(VolumeTaylorExpansionBase):
-
+class VolumeTaylorExpansion(VolumeTaylorExpansionMixin):
     expansion_terms_wrangler_class = FullExpansionTermsWrangler
-    expansion_terms_wrangler_cache = {}
 
     # not user-facing, be strict about having to pass use_rscale
     def __init__(self, kernel, order, use_rscale):
         self.expansion_terms_wrangler_key = (order, kernel.dim)
 
 
-class LinearPDEConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase):
-
+class LinearPDEConformingVolumeTaylorExpansion(VolumeTaylorExpansionMixin):
     expansion_terms_wrangler_class = LinearPDEBasedExpansionTermsWrangler
-    expansion_terms_wrangler_cache = {}
 
     # not user-facing, be strict about having to pass use_rscale
     def __init__(self, kernel, order, use_rscale):
@@ -736,21 +782,23 @@ class BiharmonicConformingVolumeTaylorExpansion(
 
 # {{{ expansion factory
 
-class ExpansionFactoryBase:
-    """An interface
+class ExpansionFactoryBase(ABC):
+    """
     .. automethod:: get_local_expansion_class
     .. automethod:: get_multipole_expansion_class
     """
 
+    @abstractmethod
     def get_local_expansion_class(self, base_kernel):
-        """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
-        raise NotImplementedError()
+        :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
+        """
 
+    @abstractmethod
     def get_multipole_expansion_class(self, base_kernel):
-        """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
-        raise NotImplementedError()
+        :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
+        """
 
 
 class VolumeTaylorExpansionFactory(ExpansionFactoryBase):
@@ -759,13 +807,15 @@ class VolumeTaylorExpansionFactory(ExpansionFactoryBase):
     """
 
     def get_local_expansion_class(self, base_kernel):
-        """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
+        """
+        :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
         from sumpy.expansion.local import VolumeTaylorLocalExpansion
         return VolumeTaylorLocalExpansion
 
     def get_multipole_expansion_class(self, base_kernel):
-        """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
+        """
+        :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
         from sumpy.expansion.multipole import VolumeTaylorMultipoleExpansion
         return VolumeTaylorMultipoleExpansion
@@ -777,7 +827,8 @@ class DefaultExpansionFactory(ExpansionFactoryBase):
     """
 
     def get_local_expansion_class(self, base_kernel):
-        """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
+        """
+        :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
         from sumpy.expansion.local import (
                 LinearPDEConformingVolumeTaylorLocalExpansion,
@@ -789,7 +840,8 @@ class DefaultExpansionFactory(ExpansionFactoryBase):
             return VolumeTaylorLocalExpansion
 
     def get_multipole_expansion_class(self, base_kernel):
-        """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
+        """
+        :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
         """
         from sumpy.expansion.multipole import (
                 LinearPDEConformingVolumeTaylorMultipoleExpansion,
diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py
index 64919b17..58809641 100644
--- a/sumpy/expansion/local.py
+++ b/sumpy/expansion/local.py
@@ -21,6 +21,7 @@ THE SOFTWARE.
 """
 
 import math
+from abc import abstractmethod
 
 from pytools import single_valued
 
@@ -28,6 +29,7 @@ import sumpy.symbolic as sym
 from sumpy.expansion import (
         ExpansionBase,
         VolumeTaylorExpansion,
+        VolumeTaylorExpansionMixin,
         LinearPDEConformingVolumeTaylorExpansion)
 from sumpy.tools import add_to_sac, mi_increment_axis
 
@@ -48,12 +50,9 @@ __doc__ = """
 class LocalExpansionBase(ExpansionBase):
     """Base class for local expansions.
 
-    .. attribute:: kernel
-    .. attribute:: order
-    .. attribute:: use_rscale
-
     .. automethod:: translate_from
     """
+
     init_arg_names = ("kernel", "order", "use_rscale", "m2l_translation")
 
     def __init__(self, kernel, order, use_rscale=None,
@@ -78,6 +77,7 @@ class LocalExpansionBase(ExpansionBase):
             and self.m2l_translation == other.m2l_translation
         )
 
+    @abstractmethod
     def translate_from(self, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, m2l_translation_classes_dependent_data=None):
         """Translate from a multipole or local expansion to a local expansion
@@ -96,13 +96,11 @@ class LocalExpansionBase(ExpansionBase):
                 expressions representing the expressions returned by
                 :func:`~sumpy.expansion.m2l.M2LTranslationBase.translation_classes_dependent_data`.
         """
-        raise NotImplementedError
 
 
 # {{{ line taylor
 
 class LineTaylorLocalExpansion(LocalExpansionBase):
-
     def get_storage_index(self, k):
         return k
 
@@ -150,12 +148,16 @@ class LineTaylorLocalExpansion(LocalExpansionBase):
                 coeffs[self.get_storage_index(i)] / math.factorial(i)
                 for i in self.get_coefficient_identifiers()))
 
+    def translate_from(self, src_expansion, src_coeff_exprs, src_rscale,
+            dvec, tgt_rscale, sac=None, m2l_translation_classes_dependent_data=None):
+        raise NotImplementedError
+
 # }}}
 
 
 # {{{ volume taylor
 
-class VolumeTaylorLocalExpansionBase(LocalExpansionBase):
+class VolumeTaylorLocalExpansionBase(VolumeTaylorExpansionMixin, LocalExpansionBase):
     """
     Coefficients represent derivative values of the kernel.
     """
@@ -463,14 +465,21 @@ class BiharmonicConformingVolumeTaylorLocalExpansion(
 # {{{ 2D Bessel-based-expansion
 
 class _FourierBesselLocalExpansion(LocalExpansionBase):
-    def __init__(self, kernel, order, use_rscale=None,
-            m2l_translation=None):
+    def __init__(self,
+            kernel, order, mpole_expn_class,
+            use_rscale=None, m2l_translation=None):
         if not m2l_translation:
             from sumpy.expansion.m2l import DefaultM2LTranslationClassFactory
             factory = DefaultM2LTranslationClassFactory()
-            m2l_translation = factory.get_m2l_translation_class(kernel,
-                self.__class__)()
+            m2l_translation = (
+                factory.get_m2l_translation_class(kernel, self.__class__)())
+
         super().__init__(kernel, order, use_rscale, m2l_translation)
+        self.mpole_expn_class = mpole_expn_class
+
+    @abstractmethod
+    def get_bessel_arg_scaling(self):
+        pass
 
     def get_storage_index(self, k):
         return self.order+k
@@ -561,10 +570,10 @@ class H2DLocalExpansion(_FourierBesselLocalExpansion):
         assert (isinstance(kernel.get_base_kernel(), HelmholtzKernel)
                 and kernel.dim == 2)
 
-        super().__init__(kernel, order, use_rscale, m2l_translation=m2l_translation)
-
         from sumpy.expansion.multipole import H2DMultipoleExpansion
-        self.mpole_expn_class = H2DMultipoleExpansion
+        super().__init__(kernel, order, H2DMultipoleExpansion,
+            use_rscale=use_rscale,
+            m2l_translation=m2l_translation)
 
     def get_bessel_arg_scaling(self):
         return sym.Symbol(self.kernel.get_base_kernel().helmholtz_k_name)
@@ -576,10 +585,10 @@ class Y2DLocalExpansion(_FourierBesselLocalExpansion):
         assert (isinstance(kernel.get_base_kernel(), YukawaKernel)
                 and kernel.dim == 2)
 
-        super().__init__(kernel, order, use_rscale, m2l_translation=m2l_translation)
-
         from sumpy.expansion.multipole import Y2DMultipoleExpansion
-        self.mpole_expn_class = Y2DMultipoleExpansion
+        super().__init__(kernel, order, Y2DMultipoleExpansion,
+            use_rscale=use_rscale,
+            m2l_translation=m2l_translation)
 
     def get_bessel_arg_scaling(self):
         return sym.I * sym.Symbol(self.kernel.get_base_kernel().yukawa_lambda_name)
diff --git a/sumpy/expansion/m2l.py b/sumpy/expansion/m2l.py
index e7c4793f..3d4b89e7 100644
--- a/sumpy/expansion/m2l.py
+++ b/sumpy/expansion/m2l.py
@@ -20,7 +20,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Tuple, Any
+from abc import ABC, abstractmethod
+from typing import Any, ClassVar, Tuple
 
 import pymbolic
 import loopy as lp
@@ -47,16 +48,16 @@ __doc__ = """
 
 # {{{ M2L translation factory
 
-class M2LTranslationClassFactoryBase:
-    """An interface
+class M2LTranslationClassFactoryBase(ABC):
+    """
     .. automethod:: get_m2l_translation_class
     """
 
+    @abstractmethod
     def get_m2l_translation_class(self, base_kernel, local_expansion_class):
         """Returns a subclass of :class:`M2LTranslationBase` suitable for
         *base_kernel* and *local_expansion_class*.
         """
-        raise NotImplementedError()
 
 
 class NonFFTM2LTranslationClassFactory(M2LTranslationClassFactoryBase):
@@ -113,12 +114,12 @@ class DefaultM2LTranslationClassFactory(M2LTranslationClassFactoryBase):
             raise RuntimeError(
                 f"Unknown local_expansion_class: {local_expansion_class}")
 
-
 # }}}
 
+
 # {{{ M2LTranslationBase
 
-class M2LTranslationBase:
+class M2LTranslationBase(ABC):
     """Base class for Multipole to Local Translation
 
     .. automethod:: translate
@@ -133,10 +134,10 @@ class M2LTranslationBase:
     .. autoattribute:: use_preprocessing
     """
 
-    use_fft = False
-    use_preprocessing = False
+    use_fft: ClassVar[bool] = False
+    use_preprocessing: ClassVar[bool] = False
 
-    def __setattr__(self):
+    def __setattr__(self, name, value):
         # These are intended to be stateless.
         raise AttributeError(f"{type(self)} is stateless and does not permit "
                 "attribute modification.")
@@ -144,9 +145,10 @@ class M2LTranslationBase:
     def __eq__(self, other):
         return type(self) is type(other)
 
+    @abstractmethod
     def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None):
-        raise NotImplementedError
+        pass
 
     def loopy_translate(self, tgt_expansion, src_expansion):
         raise NotImplementedError(
@@ -190,6 +192,7 @@ class M2LTranslationBase:
         return translation_classes_dependent_data_loopy_knl(tgt_expansion,
             src_expansion, result_dtype)
 
+    @abstractmethod
     def preprocess_multipole_exprs(self, tgt_expansion, src_expansion,
             src_coeff_exprs, sac, src_rscale):
         """Return the preprocessed multipole expansion for an optimized M2L.
@@ -203,7 +206,6 @@ class M2LTranslationBase:
         expansion coefficients with zeros added to make the M2L computation a
         circulant matvec.
         """
-        raise NotImplementedError
 
     def preprocess_multipole_nexprs(self, tgt_expansion, src_expansion):
         """Return the number of expressions returned by
@@ -217,6 +219,7 @@ class M2LTranslationBase:
         return self.translation_classes_dependent_ndata(tgt_expansion,
             src_expansion)
 
+    @abstractmethod
     def postprocess_local_exprs(self, tgt_expansion, src_expansion, m2l_result,
             src_rscale, tgt_rscale, sac):
         """Return postprocessed local expansion for an optimized M2L.
@@ -227,7 +230,6 @@ class M2LTranslationBase:
         When FFT is turned on, the output expressions are assumed to have been
         transformed from Fourier space back to the original space by the caller.
         """
-        raise NotImplementedError
 
     def postprocess_local_nexprs(self, tgt_expansion, src_expansion):
         """Return the number of expressions given as input to
@@ -244,9 +246,9 @@ class M2LTranslationBase:
     def update_persistent_hash(self, key_hash, key_builder):
         key_hash.update(type(self).__name__.encode("utf8"))
 
-
 # }}} M2LTranslationBase
 
+
 # {{{ VolumeTaylorM2LTranslation
 
 class VolumeTaylorM2LTranslation(M2LTranslationBase):
@@ -627,13 +629,13 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase):
             fixed_parameters=fixed_parameters,
         )
 
-
 # }}} VolumeTaylorM2LTranslation
 
+
 # {{{ VolumeTaylorM2LWithPreprocessedMultipoles
 
 class VolumeTaylorM2LWithPreprocessedMultipoles(VolumeTaylorM2LTranslation):
-    use_preprocessing = True
+    use_preprocessing: ClassVar[bool] = True
 
     def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None):
@@ -690,13 +692,13 @@ class VolumeTaylorM2LWithPreprocessedMultipoles(VolumeTaylorM2LTranslation):
                 lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
                 )
 
-
 # }}} VolumeTaylorM2LWithPreprocessedMultipoles
 
+
 # {{{ VolumeTaylorM2LWithFFT
 
 class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles):
-    use_fft = True
+    use_fft: ClassVar[bool] = True
 
     def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None):
@@ -738,9 +740,9 @@ class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles):
         return super().postprocess_local_exprs(tgt_expansion,
             src_expansion, m2l_result, src_rscale, tgt_rscale, sac)
 
-
 # }}} VolumeTaylorM2LWithFFT
 
+
 # {{{ FourierBesselM2LTranslation
 
 class FourierBesselM2LTranslation(M2LTranslationBase):
@@ -823,13 +825,13 @@ class FourierBesselM2LTranslation(M2LTranslationBase):
     def postprocess_local_nexprs(self, tgt_expansion, src_expansion):
         return 2*tgt_expansion.order + 1
 
-
 # }}} FourierBesselM2LTranslation
 
+
 # {{{ FourierBesselM2LWithPreprocessedMultipoles
 
 class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation):
-    use_preprocessing = True
+    use_preprocessing: ClassVar[bool] = True
 
     def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None):
@@ -846,8 +848,8 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation):
         return translated_coeffs
 
     def loopy_translate(self, tgt_expansion, src_expansion):
-        ncoeff_src = self.preprocess_multipole_nexprs(src_expansion)
-        ncoeff_tgt = self.postprocess_local_nexprs(src_expansion)
+        ncoeff_src = self.preprocess_multipole_nexprs(tgt_expansion, src_expansion)
+        ncoeff_tgt = self.postprocess_local_nexprs(tgt_expansion, src_expansion)
 
         icoeff_src = pymbolic.var("icoeff_src")
         icoeff_tgt = pymbolic.var("icoeff_tgt")
@@ -857,7 +859,7 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation):
         src_coeffs = pymbolic.var("src_coeffs")
         translation_classes_dependent_data = pymbolic.var("data")
 
-        if self.use_fft_for_m2l:
+        if self.use_fft:
             expr = src_coeffs[icoeff_tgt] \
                     * translation_classes_dependent_data[icoeff_tgt]
         else:
@@ -884,13 +886,13 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation):
                 lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
                 )
 
-
 # }}} FourierBesselM2LWithPreprocessedMultipoles
 
+
 # {{{ FourierBesselM2LWithFFT
 
 class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles):
-    use_fft = True
+    use_fft: ClassVar[bool] = True
 
     def __init__(self):
         # FIXME: expansion with FFT is correct symbolically and can be verified
@@ -899,8 +901,7 @@ class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles):
         # instability but gives rscale as a possible solution. Sumpy's rscale
         # choice is slightly different from Greengard and Rokhlin and that
         # might be the reason for this numerical issue.
-        raise ValueError("Bessel based expansions with FFT is not fully "
-                         "supported yet.")
+        raise ValueError("Bessel based expansions with FFT are not supported yet.")
 
     def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale,
             dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None):
@@ -956,9 +957,9 @@ class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles):
         return super().postprocess_local_exprs(tgt_expansion,
             src_expansion, m2l_result, src_rscale, tgt_rscale, sac)
 
-
 # }}} FourierBesselM2LWithFFT
 
+
 # {{{ translation_classes_dependent_data_loopy_knl
 
 def translation_classes_dependent_data_loopy_knl(tgt_expansion, src_expansion,
diff --git a/sumpy/expansion/multipole.py b/sumpy/expansion/multipole.py
index 1924e6c7..a344fe0a 100644
--- a/sumpy/expansion/multipole.py
+++ b/sumpy/expansion/multipole.py
@@ -21,10 +21,14 @@ THE SOFTWARE.
 """
 
 import math
+from abc import abstractmethod
 
 import sumpy.symbolic as sym
 from sumpy.expansion import (
-    ExpansionBase, VolumeTaylorExpansion, LinearPDEConformingVolumeTaylorExpansion)
+    ExpansionBase,
+    VolumeTaylorExpansion,
+    VolumeTaylorExpansionMixin,
+    LinearPDEConformingVolumeTaylorExpansion)
 from sumpy.tools import mi_set_axis, add_to_sac, mi_power, mi_factorial
 
 import logging
@@ -46,7 +50,8 @@ class MultipoleExpansionBase(ExpansionBase):
 
 # {{{ volume taylor
 
-class VolumeTaylorMultipoleExpansionBase(MultipoleExpansionBase):
+class VolumeTaylorMultipoleExpansionBase(
+        VolumeTaylorExpansionMixin, MultipoleExpansionBase):
     """
     Coefficients represent the terms in front of the kernel derivatives.
     """
@@ -396,6 +401,10 @@ class BiharmonicConformingVolumeTaylorMultipoleExpansion(
 # {{{ 2D Hankel-based expansions
 
 class _HankelBased2DMultipoleExpansion(MultipoleExpansionBase):
+    @abstractmethod
+    def get_bessel_arg_scaling(self):
+        return
+
     def get_storage_index(self, k):
         return self.order+k
 
-- 
GitLab