diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py index 6c1fd28a3d9bb6e4fbb3484d9b303a7f446b4711..e321993367925b7be9fa56a48473ab9bd4e876ff 100644 --- a/sumpy/expansion/local.py +++ b/sumpy/expansion/local.py @@ -91,11 +91,11 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): """ def coefficients_from_source(self, avec, bvec): - from sumpy.tools import mi_derivative + from sumpy.tools import MiDerivativeTaker ppkernel = self.kernel.postprocess_at_source( self.kernel.get_expression(avec), avec) - return [mi_derivative(ppkernel, avec, mi) - for mi in self.get_coefficient_identifiers()] + taker = MiDerivativeTaker(ppkernel, avec) + return [taker.diff(mi) for mi in self.get_coefficient_identifiers()] def evaluate(self, coeffs, bvec): from sumpy.tools import mi_power, mi_factorial @@ -115,10 +115,10 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): type(self).__name__, self.order)) - from sumpy.tools import mi_derivative + from sumpy.tools import MiDerivativeTaker expr = src_expansion.evaluate(src_coeff_exprs, dvec) - result = [mi_derivative(expr, dvec, mi) - for mi in self.get_coefficient_identifiers()] + taker = MiDerivativeTaker(expr, dvec) + result = [taker.diff(mi) for mi in self.get_coefficient_identifiers()] logger.info("building translation operator: done") return result diff --git a/sumpy/expansion/multipole.py b/sumpy/expansion/multipole.py index f579bd304d404a862196aa46eb136c5d49194ebc..79412e17a248e2d647ddbfe5a7a410dfd6bc72d0 100644 --- a/sumpy/expansion/multipole.py +++ b/sumpy/expansion/multipole.py @@ -92,9 +92,10 @@ class VolumeTaylorMultipoleExpansionBase(MultipoleExpansionBase): ppkernel = self.kernel.postprocess_at_target( self.kernel.get_expression(bvec), bvec) - from sumpy.tools import mi_derivative + from sumpy.tools import MiDerivativeTaker + taker = MiDerivativeTaker(ppkernel, bvec) result = sum( - coeff * mi_derivative(ppkernel, bvec, mi) + coeff * taker.diff(mi) for coeff, mi in zip(coeffs, self.get_coefficient_identifiers())) return result diff --git a/sumpy/tools.py b/sumpy/tools.py index d2d3a711e08cdf183b5840c60a0199a95e585281..88ff5241799c210a522e8b54935ff0d1ae9dde47 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -54,10 +54,33 @@ def mi_power(vector, mi): return result -def mi_derivative(expr, vector, mi): - for mi_i, vec_i in zip(mi, vector): - expr = expr.diff(vec_i, mi_i) - return expr +class MiDerivativeTaker(object): + + def __init__(self, expr, var_list): + self.var_list = var_list + empty_mi = (0,) * len(var_list) + self.cache_by_mi = {empty_mi: expr} + + def mi_dist(self, a, b): + return np.array(a, dtype=int) - np.array(b, dtype=int) + + def diff(self, mi): + closest_mi = min( + (other_mi + for other_mi in self.cache_by_mi.keys() + if (np.array(mi) >= np.array(other_mi)).all()), + key=lambda other_mi: sum(self.mi_dist(mi, other_mi))) + + expr = self.cache_by_mi[closest_mi] + current_mi = np.array(closest_mi, dtype=int) + for idx, (mi_i, vec_i) in enumerate( + zip(self.mi_dist(mi, closest_mi), self.var_list)): + for i in range(1, 1 + mi_i): + current_mi[idx] += 1 + expr = expr.diff(vec_i) + self.cache_by_mi[tuple(current_mi)] = expr + + return expr # }}}