From 6a961e7e6a3cfbc8bf06702afc131a79e6c71d51 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sun, 18 Sep 2022 11:40:19 +0300 Subject: [PATCH] fixup inheritence and pylint in expansions --- sumpy/expansion/__init__.py | 208 ++++++++++++++++++++++------------- sumpy/expansion/local.py | 43 +++++--- sumpy/expansion/m2l.py | 57 +++++----- sumpy/expansion/multipole.py | 13 ++- 4 files changed, 196 insertions(+), 125 deletions(-) diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py index 99ed13b3..ca974b93 100644 --- a/sumpy/expansion/__init__.py +++ b/sumpy/expansion/__init__.py @@ -20,14 +20,26 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import logging +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Dict, Hashable, List, Tuple, Type + from pytools import memoize_method + import sumpy.symbolic as sym from sumpy.tools import add_mi -from typing import List, Tuple + +import logging +logger = logging.getLogger(__name__) + __doc__ = """ .. autoclass:: ExpansionBase + +Expansion Wranglers +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: ExpansionTermsWrangler +.. autoclass:: FullExpansionTermsWrangler .. autoclass:: LinearPDEBasedExpansionTermsWrangler Expansion Factories @@ -38,18 +50,24 @@ Expansion Factories .. autoclass:: VolumeTaylorExpansionFactory """ -logger = logging.getLogger(__name__) - -# {{{ base class +# {{{ expansion base -class ExpansionBase: +class ExpansionBase(ABC): """ - .. automethod:: with_kernel - .. automethod:: __len__ + .. attribute:: kernel + .. attribute:: order + .. attribute:: use_rscale + .. automethod:: get_coefficient_identifiers .. automethod:: coefficients_from_source - .. automethod:: translate_from + .. automethod:: coefficients_from_source_vec + .. automethod:: evaluate + + .. automethod:: with_kernel + .. automethod:: copy + + .. automethod:: __len__ .. automethod:: __eq__ .. automethod:: __ne__ """ @@ -95,18 +113,15 @@ class ExpansionBase: # }}} - def with_kernel(self, kernel): - return type(self)(kernel, self.order, self.use_rscale) - - def __len__(self): - return len(self.get_coefficient_identifiers()) + # {{{ abstract interface + @abstractmethod def get_coefficient_identifiers(self): """ - Returns the identifiers of the coefficients that actually get stored. + :returns: the identifiers of the coefficients that actually get stored. """ - raise NotImplementedError + @abstractmethod def coefficients_from_source(self, kernel, avec, bvec, rscale, sac=None): """Form an expansion from a source point. @@ -119,10 +134,9 @@ class ExpansionBase: :returns: a list of :mod:`sympy` expressions representing the coefficients of the expansion. """ - raise NotImplementedError - def coefficients_from_source_vec(self, kernels, avec, bvec, rscale, weights, - sac=None): + def coefficients_from_source_vec(self, + kernels, avec, bvec, rscale, weights, sac=None): """Form an expansion with a linear combination of kernels and weights. :arg avec: vector from source to center. @@ -141,34 +155,20 @@ class ExpansionBase: result[i] += weight * coeffs[i] return result + @abstractmethod def evaluate(self, kernel, coeffs, bvec, rscale, sac=None): """ - :return: a :mod:`sympy` expression corresponding + :returns: a :mod:`sympy` expression corresponding to the evaluated expansion with the coefficients in *coeffs*. """ - raise NotImplementedError - - def translate_from(self, src_expansion, src_coeff_exprs, src_rscale, - dvec, tgt_rscale, sac=None): - raise NotImplementedError + # }}} - def update_persistent_hash(self, key_hash, key_builder): - key_hash.update(type(self).__name__.encode("utf8")) - key_builder.rec(key_hash, self.kernel) - key_builder.rec(key_hash, self.order) - key_builder.rec(key_hash, self.use_rscale) + # {{{ copy - def __eq__(self, other): - return ( - type(self) == type(other) - and self.kernel == other.kernel - and self.order == other.order - and self.use_rscale == other.use_rscale) - - def __ne__(self, other): - return not self.__eq__(other) + def with_kernel(self, kernel): + return type(self)(kernel, self.order, self.use_rscale) def copy(self, **kwargs): new_kwargs = { @@ -184,14 +184,44 @@ class ExpansionBase: return type(self)(**new_kwargs) + # }}} + + def update_persistent_hash(self, key_hash, key_builder): + key_hash.update(type(self).__name__.encode("utf8")) + key_builder.rec(key_hash, self.kernel) + key_builder.rec(key_hash, self.order) + key_builder.rec(key_hash, self.use_rscale) + + def __len__(self): + return len(self.get_coefficient_identifiers()) + + def __eq__(self, other): + return ( + type(self) == type(other) + and self.kernel == other.kernel + and self.order == other.order + and self.use_rscale == other.use_rscale) + + def __ne__(self, other): + return not self.__eq__(other) # }}} # {{{ expansion terms wrangler -class ExpansionTermsWrangler: +class ExpansionTermsWrangler(ABC): + """ + .. attribute:: order + .. attribute:: dim + .. attribute:: max_mi + .. automethod:: get_coefficient_identifiers + .. automethod:: get_full_kernel_derivatives_from_stored + .. automethod:: get_stored_mpole_coefficients_from_full + + .. automethod:: get_full_coefficient_identifiers + """ init_arg_names = ("order", "dim", "max_mi") def __init__(self, order, dim, max_mi=None): @@ -199,16 +229,23 @@ class ExpansionTermsWrangler: self.dim = dim self.max_mi = max_mi + # {{{ abstract interface + + @abstractmethod def get_coefficient_identifiers(self): - raise NotImplementedError + pass + + @abstractmethod + def get_full_kernel_derivatives_from_stored(self, + stored_kernel_derivatives, rscale, sac=None): + pass - def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives, - rscale, sac=None): - raise NotImplementedError + @abstractmethod + def get_stored_mpole_coefficients_from_full(self, + full_mpole_coefficients, rscale, sac=None): + pass - def get_stored_mpole_coefficients_from_full(self, full_mpole_coefficients, - rscale, sac=None): - raise NotImplementedError + # }}} @memoize_method def get_full_coefficient_identifiers(self): @@ -241,6 +278,8 @@ class ExpansionTermsWrangler: return type(self)(**new_kwargs) + # {{{ hyperplane helpers + def _get_mi_hyperpplanes(self) -> List[Tuple[int, int]]: r""" Coefficient storage is organized into "hyperplanes" in multi-index @@ -303,18 +342,23 @@ class ExpansionTermsWrangler: return res + # }}} + class FullExpansionTermsWrangler(ExpansionTermsWrangler): + def get_coefficient_identifiers(self): + return super().get_full_coefficient_identifiers() - get_coefficient_identifiers = ( - ExpansionTermsWrangler.get_full_coefficient_identifiers) - - def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives, - rscale, sac=None): + def get_full_kernel_derivatives_from_stored(self, + stored_kernel_derivatives, rscale, sac=None): return stored_kernel_derivatives - get_stored_mpole_coefficients_from_full = ( - get_full_kernel_derivatives_from_stored) + def get_stored_mpole_coefficients_from_full(self, + full_mpole_coefficients, rscale, sac=None): + return self.get_full_kernel_derivatives_from_stored( + full_mpole_coefficients, rscale, sac=sac) + +# }}} # {{{ sparse matrix-vector multiplication @@ -392,6 +436,8 @@ class CSEMatVecOperator: # }}} +# {{{ LinearPDEBasedExpansionTermsWrangler + class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler): """ .. automethod:: __init__ @@ -411,18 +457,18 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler): def get_coefficient_identifiers(self): return self.stored_identifiers - def get_full_kernel_derivatives_from_stored(self, stored_kernel_derivatives, - rscale, sac=None): - + def get_full_kernel_derivatives_from_stored(self, + stored_kernel_derivatives, rscale, sac=None): from sumpy.tools import add_to_sac + projection_matrix = self.get_projection_matrix(rscale) return projection_matrix.matvec(stored_kernel_derivatives, lambda x: add_to_sac(sac, x)) - def get_stored_mpole_coefficients_from_full(self, full_mpole_coefficients, - rscale, sac=None): - + def get_stored_mpole_coefficients_from_full(self, + full_mpole_coefficients, rscale, sac=None): from sumpy.tools import add_to_sac + projection_matrix = self.get_projection_matrix(rscale) return projection_matrix.transpose_matvec(full_mpole_coefficients, lambda x: add_to_sac(sac, x)) @@ -636,9 +682,13 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler): # }}} -# {{{ volume taylor +# {{{ volume taylor expansion + +# FIXME: This is called an expansion but doesn't inherit from ExpansionBase? -class VolumeTaylorExpansionBase: +class VolumeTaylorExpansionMixin: + expansion_terms_wrangler_class: ClassVar[Type[ExpansionTermsWrangler]] + expansion_terms_wrangler_cache: ClassVar[Dict[Hashable, Any]] = {} @classmethod def get_or_make_expansion_terms_wrangler(cls, *key): @@ -679,20 +729,16 @@ class VolumeTaylorExpansionBase: return self._storage_loc_dict[i] -class VolumeTaylorExpansion(VolumeTaylorExpansionBase): - +class VolumeTaylorExpansion(VolumeTaylorExpansionMixin): expansion_terms_wrangler_class = FullExpansionTermsWrangler - expansion_terms_wrangler_cache = {} # not user-facing, be strict about having to pass use_rscale def __init__(self, kernel, order, use_rscale): self.expansion_terms_wrangler_key = (order, kernel.dim) -class LinearPDEConformingVolumeTaylorExpansion(VolumeTaylorExpansionBase): - +class LinearPDEConformingVolumeTaylorExpansion(VolumeTaylorExpansionMixin): expansion_terms_wrangler_class = LinearPDEBasedExpansionTermsWrangler - expansion_terms_wrangler_cache = {} # not user-facing, be strict about having to pass use_rscale def __init__(self, kernel, order, use_rscale): @@ -736,21 +782,23 @@ class BiharmonicConformingVolumeTaylorExpansion( # {{{ expansion factory -class ExpansionFactoryBase: - """An interface +class ExpansionFactoryBase(ABC): + """ .. automethod:: get_local_expansion_class .. automethod:: get_multipole_expansion_class """ + @abstractmethod def get_local_expansion_class(self, base_kernel): - """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*. """ - raise NotImplementedError() + :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*. + """ + @abstractmethod def get_multipole_expansion_class(self, base_kernel): - """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*. """ - raise NotImplementedError() + :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*. + """ class VolumeTaylorExpansionFactory(ExpansionFactoryBase): @@ -759,13 +807,15 @@ class VolumeTaylorExpansionFactory(ExpansionFactoryBase): """ def get_local_expansion_class(self, base_kernel): - """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*. + """ + :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*. """ from sumpy.expansion.local import VolumeTaylorLocalExpansion return VolumeTaylorLocalExpansion def get_multipole_expansion_class(self, base_kernel): - """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*. + """ + :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*. """ from sumpy.expansion.multipole import VolumeTaylorMultipoleExpansion return VolumeTaylorMultipoleExpansion @@ -777,7 +827,8 @@ class DefaultExpansionFactory(ExpansionFactoryBase): """ def get_local_expansion_class(self, base_kernel): - """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*. + """ + :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*. """ from sumpy.expansion.local import ( LinearPDEConformingVolumeTaylorLocalExpansion, @@ -789,7 +840,8 @@ class DefaultExpansionFactory(ExpansionFactoryBase): return VolumeTaylorLocalExpansion def get_multipole_expansion_class(self, base_kernel): - """Returns a subclass of :class:`ExpansionBase` suitable for *base_kernel*. + """ + :returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*. """ from sumpy.expansion.multipole import ( LinearPDEConformingVolumeTaylorMultipoleExpansion, diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py index 64919b17..58809641 100644 --- a/sumpy/expansion/local.py +++ b/sumpy/expansion/local.py @@ -21,6 +21,7 @@ THE SOFTWARE. """ import math +from abc import abstractmethod from pytools import single_valued @@ -28,6 +29,7 @@ import sumpy.symbolic as sym from sumpy.expansion import ( ExpansionBase, VolumeTaylorExpansion, + VolumeTaylorExpansionMixin, LinearPDEConformingVolumeTaylorExpansion) from sumpy.tools import add_to_sac, mi_increment_axis @@ -48,12 +50,9 @@ __doc__ = """ class LocalExpansionBase(ExpansionBase): """Base class for local expansions. - .. attribute:: kernel - .. attribute:: order - .. attribute:: use_rscale - .. automethod:: translate_from """ + init_arg_names = ("kernel", "order", "use_rscale", "m2l_translation") def __init__(self, kernel, order, use_rscale=None, @@ -78,6 +77,7 @@ class LocalExpansionBase(ExpansionBase): and self.m2l_translation == other.m2l_translation ) + @abstractmethod def translate_from(self, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, m2l_translation_classes_dependent_data=None): """Translate from a multipole or local expansion to a local expansion @@ -96,13 +96,11 @@ class LocalExpansionBase(ExpansionBase): expressions representing the expressions returned by :func:`~sumpy.expansion.m2l.M2LTranslationBase.translation_classes_dependent_data`. """ - raise NotImplementedError # {{{ line taylor class LineTaylorLocalExpansion(LocalExpansionBase): - def get_storage_index(self, k): return k @@ -150,12 +148,16 @@ class LineTaylorLocalExpansion(LocalExpansionBase): coeffs[self.get_storage_index(i)] / math.factorial(i) for i in self.get_coefficient_identifiers())) + def translate_from(self, src_expansion, src_coeff_exprs, src_rscale, + dvec, tgt_rscale, sac=None, m2l_translation_classes_dependent_data=None): + raise NotImplementedError + # }}} # {{{ volume taylor -class VolumeTaylorLocalExpansionBase(LocalExpansionBase): +class VolumeTaylorLocalExpansionBase(VolumeTaylorExpansionMixin, LocalExpansionBase): """ Coefficients represent derivative values of the kernel. """ @@ -463,14 +465,21 @@ class BiharmonicConformingVolumeTaylorLocalExpansion( # {{{ 2D Bessel-based-expansion class _FourierBesselLocalExpansion(LocalExpansionBase): - def __init__(self, kernel, order, use_rscale=None, - m2l_translation=None): + def __init__(self, + kernel, order, mpole_expn_class, + use_rscale=None, m2l_translation=None): if not m2l_translation: from sumpy.expansion.m2l import DefaultM2LTranslationClassFactory factory = DefaultM2LTranslationClassFactory() - m2l_translation = factory.get_m2l_translation_class(kernel, - self.__class__)() + m2l_translation = ( + factory.get_m2l_translation_class(kernel, self.__class__)()) + super().__init__(kernel, order, use_rscale, m2l_translation) + self.mpole_expn_class = mpole_expn_class + + @abstractmethod + def get_bessel_arg_scaling(self): + pass def get_storage_index(self, k): return self.order+k @@ -561,10 +570,10 @@ class H2DLocalExpansion(_FourierBesselLocalExpansion): assert (isinstance(kernel.get_base_kernel(), HelmholtzKernel) and kernel.dim == 2) - super().__init__(kernel, order, use_rscale, m2l_translation=m2l_translation) - from sumpy.expansion.multipole import H2DMultipoleExpansion - self.mpole_expn_class = H2DMultipoleExpansion + super().__init__(kernel, order, H2DMultipoleExpansion, + use_rscale=use_rscale, + m2l_translation=m2l_translation) def get_bessel_arg_scaling(self): return sym.Symbol(self.kernel.get_base_kernel().helmholtz_k_name) @@ -576,10 +585,10 @@ class Y2DLocalExpansion(_FourierBesselLocalExpansion): assert (isinstance(kernel.get_base_kernel(), YukawaKernel) and kernel.dim == 2) - super().__init__(kernel, order, use_rscale, m2l_translation=m2l_translation) - from sumpy.expansion.multipole import Y2DMultipoleExpansion - self.mpole_expn_class = Y2DMultipoleExpansion + super().__init__(kernel, order, Y2DMultipoleExpansion, + use_rscale=use_rscale, + m2l_translation=m2l_translation) def get_bessel_arg_scaling(self): return sym.I * sym.Symbol(self.kernel.get_base_kernel().yukawa_lambda_name) diff --git a/sumpy/expansion/m2l.py b/sumpy/expansion/m2l.py index e7c4793f..3d4b89e7 100644 --- a/sumpy/expansion/m2l.py +++ b/sumpy/expansion/m2l.py @@ -20,7 +20,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Tuple, Any +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Tuple import pymbolic import loopy as lp @@ -47,16 +48,16 @@ __doc__ = """ # {{{ M2L translation factory -class M2LTranslationClassFactoryBase: - """An interface +class M2LTranslationClassFactoryBase(ABC): + """ .. automethod:: get_m2l_translation_class """ + @abstractmethod def get_m2l_translation_class(self, base_kernel, local_expansion_class): """Returns a subclass of :class:`M2LTranslationBase` suitable for *base_kernel* and *local_expansion_class*. """ - raise NotImplementedError() class NonFFTM2LTranslationClassFactory(M2LTranslationClassFactoryBase): @@ -113,12 +114,12 @@ class DefaultM2LTranslationClassFactory(M2LTranslationClassFactoryBase): raise RuntimeError( f"Unknown local_expansion_class: {local_expansion_class}") - # }}} + # {{{ M2LTranslationBase -class M2LTranslationBase: +class M2LTranslationBase(ABC): """Base class for Multipole to Local Translation .. automethod:: translate @@ -133,10 +134,10 @@ class M2LTranslationBase: .. autoattribute:: use_preprocessing """ - use_fft = False - use_preprocessing = False + use_fft: ClassVar[bool] = False + use_preprocessing: ClassVar[bool] = False - def __setattr__(self): + def __setattr__(self, name, value): # These are intended to be stateless. raise AttributeError(f"{type(self)} is stateless and does not permit " "attribute modification.") @@ -144,9 +145,10 @@ class M2LTranslationBase: def __eq__(self, other): return type(self) is type(other) + @abstractmethod def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None): - raise NotImplementedError + pass def loopy_translate(self, tgt_expansion, src_expansion): raise NotImplementedError( @@ -190,6 +192,7 @@ class M2LTranslationBase: return translation_classes_dependent_data_loopy_knl(tgt_expansion, src_expansion, result_dtype) + @abstractmethod def preprocess_multipole_exprs(self, tgt_expansion, src_expansion, src_coeff_exprs, sac, src_rscale): """Return the preprocessed multipole expansion for an optimized M2L. @@ -203,7 +206,6 @@ class M2LTranslationBase: expansion coefficients with zeros added to make the M2L computation a circulant matvec. """ - raise NotImplementedError def preprocess_multipole_nexprs(self, tgt_expansion, src_expansion): """Return the number of expressions returned by @@ -217,6 +219,7 @@ class M2LTranslationBase: return self.translation_classes_dependent_ndata(tgt_expansion, src_expansion) + @abstractmethod def postprocess_local_exprs(self, tgt_expansion, src_expansion, m2l_result, src_rscale, tgt_rscale, sac): """Return postprocessed local expansion for an optimized M2L. @@ -227,7 +230,6 @@ class M2LTranslationBase: When FFT is turned on, the output expressions are assumed to have been transformed from Fourier space back to the original space by the caller. """ - raise NotImplementedError def postprocess_local_nexprs(self, tgt_expansion, src_expansion): """Return the number of expressions given as input to @@ -244,9 +246,9 @@ class M2LTranslationBase: def update_persistent_hash(self, key_hash, key_builder): key_hash.update(type(self).__name__.encode("utf8")) - # }}} M2LTranslationBase + # {{{ VolumeTaylorM2LTranslation class VolumeTaylorM2LTranslation(M2LTranslationBase): @@ -627,13 +629,13 @@ class VolumeTaylorM2LTranslation(M2LTranslationBase): fixed_parameters=fixed_parameters, ) - # }}} VolumeTaylorM2LTranslation + # {{{ VolumeTaylorM2LWithPreprocessedMultipoles class VolumeTaylorM2LWithPreprocessedMultipoles(VolumeTaylorM2LTranslation): - use_preprocessing = True + use_preprocessing: ClassVar[bool] = True def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None): @@ -690,13 +692,13 @@ class VolumeTaylorM2LWithPreprocessedMultipoles(VolumeTaylorM2LTranslation): lang_version=lp.MOST_RECENT_LANGUAGE_VERSION, ) - # }}} VolumeTaylorM2LWithPreprocessedMultipoles + # {{{ VolumeTaylorM2LWithFFT class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles): - use_fft = True + use_fft: ClassVar[bool] = True def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None): @@ -738,9 +740,9 @@ class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles): return super().postprocess_local_exprs(tgt_expansion, src_expansion, m2l_result, src_rscale, tgt_rscale, sac) - # }}} VolumeTaylorM2LWithFFT + # {{{ FourierBesselM2LTranslation class FourierBesselM2LTranslation(M2LTranslationBase): @@ -823,13 +825,13 @@ class FourierBesselM2LTranslation(M2LTranslationBase): def postprocess_local_nexprs(self, tgt_expansion, src_expansion): return 2*tgt_expansion.order + 1 - # }}} FourierBesselM2LTranslation + # {{{ FourierBesselM2LWithPreprocessedMultipoles class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation): - use_preprocessing = True + use_preprocessing: ClassVar[bool] = True def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None): @@ -846,8 +848,8 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation): return translated_coeffs def loopy_translate(self, tgt_expansion, src_expansion): - ncoeff_src = self.preprocess_multipole_nexprs(src_expansion) - ncoeff_tgt = self.postprocess_local_nexprs(src_expansion) + ncoeff_src = self.preprocess_multipole_nexprs(tgt_expansion, src_expansion) + ncoeff_tgt = self.postprocess_local_nexprs(tgt_expansion, src_expansion) icoeff_src = pymbolic.var("icoeff_src") icoeff_tgt = pymbolic.var("icoeff_tgt") @@ -857,7 +859,7 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation): src_coeffs = pymbolic.var("src_coeffs") translation_classes_dependent_data = pymbolic.var("data") - if self.use_fft_for_m2l: + if self.use_fft: expr = src_coeffs[icoeff_tgt] \ * translation_classes_dependent_data[icoeff_tgt] else: @@ -884,13 +886,13 @@ class FourierBesselM2LWithPreprocessedMultipoles(FourierBesselM2LTranslation): lang_version=lp.MOST_RECENT_LANGUAGE_VERSION, ) - # }}} FourierBesselM2LWithPreprocessedMultipoles + # {{{ FourierBesselM2LWithFFT class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles): - use_fft = True + use_fft: ClassVar[bool] = True def __init__(self): # FIXME: expansion with FFT is correct symbolically and can be verified @@ -899,8 +901,7 @@ class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles): # instability but gives rscale as a possible solution. Sumpy's rscale # choice is slightly different from Greengard and Rokhlin and that # might be the reason for this numerical issue. - raise ValueError("Bessel based expansions with FFT is not fully " - "supported yet.") + raise ValueError("Bessel based expansions with FFT are not supported yet.") def translate(self, tgt_expansion, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, translation_classes_dependent_data=None): @@ -956,9 +957,9 @@ class FourierBesselM2LWithFFT(FourierBesselM2LWithPreprocessedMultipoles): return super().postprocess_local_exprs(tgt_expansion, src_expansion, m2l_result, src_rscale, tgt_rscale, sac) - # }}} FourierBesselM2LWithFFT + # {{{ translation_classes_dependent_data_loopy_knl def translation_classes_dependent_data_loopy_knl(tgt_expansion, src_expansion, diff --git a/sumpy/expansion/multipole.py b/sumpy/expansion/multipole.py index 1924e6c7..a344fe0a 100644 --- a/sumpy/expansion/multipole.py +++ b/sumpy/expansion/multipole.py @@ -21,10 +21,14 @@ THE SOFTWARE. """ import math +from abc import abstractmethod import sumpy.symbolic as sym from sumpy.expansion import ( - ExpansionBase, VolumeTaylorExpansion, LinearPDEConformingVolumeTaylorExpansion) + ExpansionBase, + VolumeTaylorExpansion, + VolumeTaylorExpansionMixin, + LinearPDEConformingVolumeTaylorExpansion) from sumpy.tools import mi_set_axis, add_to_sac, mi_power, mi_factorial import logging @@ -46,7 +50,8 @@ class MultipoleExpansionBase(ExpansionBase): # {{{ volume taylor -class VolumeTaylorMultipoleExpansionBase(MultipoleExpansionBase): +class VolumeTaylorMultipoleExpansionBase( + VolumeTaylorExpansionMixin, MultipoleExpansionBase): """ Coefficients represent the terms in front of the kernel derivatives. """ @@ -396,6 +401,10 @@ class BiharmonicConformingVolumeTaylorMultipoleExpansion( # {{{ 2D Hankel-based expansions class _HankelBased2DMultipoleExpansion(MultipoleExpansionBase): + @abstractmethod + def get_bessel_arg_scaling(self): + return + def get_storage_index(self, k): return self.order+k -- GitLab