diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py index 413d35ab86195420690b28f09138c781d29b730a..4f0a3a01482af78dacbfe0635ce45249a610f930 100644 --- a/sumpy/expansion/__init__.py +++ b/sumpy/expansion/__init__.py @@ -162,6 +162,8 @@ class ExpansionBase(object): class ExpansionTermsWrangler(object): + init_arg_names = ("order", "dim") + def __init__(self, order, dim): self.order = order self.dim = dim @@ -190,14 +192,18 @@ class ExpansionTermsWrangler(object): return res def copy(self, **kwargs): - order = kwargs.pop('order', self.order) - dim = kwargs.pop('dim', self.dim) + new_kwargs = dict( + (name, getattr(self, name)) + for name in self.init_arg_names) + + new_kwargs["order"] = kwargs.pop("order", self.order) + new_kwargs["dim"] = kwargs.pop("dim", self.dim) if kwargs: raise TypeError("unexpected keyword arguments '%s'" % ", ".join(kwargs)) - return type(self)(order, dim) + return type(self)(**new_kwargs) class FullExpansionTermsWrangler(ExpansionTermsWrangler): @@ -255,6 +261,8 @@ class LinearRecurrenceBasedExpansionTermsWrangler(ExpansionTermsWrangler): .. automethod:: get_reduced_coeffs """ + init_arg_names = ("order", "dim", "deriv_multiplier") + def __init__(self, order, dim, deriv_multiplier): r""" :param order: order of the expansion @@ -412,7 +420,9 @@ class LinearRecurrenceBasedExpansionTermsWrangler(ExpansionTermsWrangler): \sum_{\nu,c_\nu\in \text{pde\_dict}} \frac{c_\nu\cdot \alpha_\nu} - {\text{deriv\_multiplier}^{\sum \text{mi}}} = 0, + {\text{deriv\_multiplier}^{ + \sum \text{mi} + }} = 0, where :math:`\mathbf\alpha` is a coefficient vector. @@ -443,6 +453,8 @@ class LinearRecurrenceBasedExpansionTermsWrangler(ExpansionTermsWrangler): class LaplaceExpansionTermsWrangler(LinearRecurrenceBasedExpansionTermsWrangler): + init_arg_names = ("order", "dim") + def __init__(self, order, dim): super(LaplaceExpansionTermsWrangler, self).__init__(order, dim, 1) @@ -466,7 +478,11 @@ class LaplaceExpansionTermsWrangler(LinearRecurrenceBasedExpansionTermsWrangler) class HelmholtzExpansionTermsWrangler(LinearRecurrenceBasedExpansionTermsWrangler): + init_arg_names = ("order", "dim", "helmholtz_k_name") + def __init__(self, order, dim, helmholtz_k_name): + self.helmholtz_k_name = helmholtz_k_name + multiplier = sym.Symbol(helmholtz_k_name) super(HelmholtzExpansionTermsWrangler, self).__init__(order, dim, multiplier)