From 95b2d3554af942e2a31a552844932cc7f3aa74cc Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 23 Feb 2017 18:34:07 -0600 Subject: [PATCH] Fix LineTaylorLocalExpansion to work with SymEngine (still needed: tests in sumpy, but that's a long-standing bug: see #2). --- sumpy/expansion/local.py | 25 ++++++++++++-------- sumpy/tools.py | 49 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py index bdcd9043..f9ec106a 100644 --- a/sumpy/expansion/local.py +++ b/sumpy/expansion/local.py @@ -61,24 +61,29 @@ class LineTaylorLocalExpansion(LocalExpansionBase): raise RuntimeError("cannot use line-Taylor expansions in a setting " "where the center-target vector is not known at coefficient " "formation") - avec_line = avec + sym.Symbol("tau")*bvec + + tau = sym.Symbol("tau") + + avec_line = avec + tau*bvec line_kernel = self.kernel.get_expression(avec_line) - return [ - self.kernel.postprocess_at_target( - self.kernel.postprocess_at_source( - line_kernel.diff("tau", i), - avec), - bvec) - .subs("tau", 0) + from sumpy.tools import MiDerivativeTaker, my_syntactic_subs + deriv_taker = MiDerivativeTaker(line_kernel, (tau,)) + + return [my_syntactic_subs( + self.kernel.postprocess_at_target( + self.kernel.postprocess_at_source( + deriv_taker.diff(i), + avec), bvec), + {tau: 0}) for i in self.get_coefficient_identifiers()] def evaluate(self, coeffs, bvec): from pytools import factorial - return sum( + return sym.Add(*( coeffs[self.get_storage_index(i)] / factorial(i) - for i in self.get_coefficient_identifiers()) + for i in self.get_coefficient_identifiers())) # }}} diff --git a/sumpy/tools.py b/sumpy/tools.py index 51cc52be..af125566 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -385,4 +385,53 @@ class KernelCacheWrapper(object): return knl + +def my_syntactic_subs(expr, subst_dict): + # Workaround for differing substitution semantics between sympy and symengine. + # FIXME: This is a hack. + from sumpy.symbolic import Basic, Subs, Derivative, USE_SYMENGINE + + if not isinstance(expr, Basic): + return expr + + elif expr.is_Symbol: + return subst_dict.get(expr, expr) + + elif isinstance(expr, Subs): + new_point = tuple(my_syntactic_subs(p, subst_dict) for p in expr.point) + + import six + new_subst_dict = dict( + (var, subs) for var, subs in six.iteritems(subst_dict) + if var not in expr.variables) + + new_expr = my_syntactic_subs(expr.expr, new_subst_dict) + + if new_point != expr.point or new_expr != expr.expr: + return Subs(new_expr, expr.variables, new_point) + + return expr + + elif isinstance(expr, Derivative): + new_expr = my_syntactic_subs(expr.expr, subst_dict) + new_variables = my_syntactic_subs(expr.variables, subst_dict) + + if new_expr != expr.expr or any(new_var != var for new_var, var in + zip(new_variables, expr.variables)): + # FIXME in SymEngine + if USE_SYMENGINE: + return Derivative(new_expr, new_variables) + else: + return Derivative(new_expr, *new_variables) + + return expr + + else: + new_args = tuple(my_syntactic_subs(arg, subst_dict) for arg in expr.args) + if any(new_arg != arg for arg, new_arg in zip(expr.args, new_args)): + return expr.func(*new_args) + else: + return expr + + # vim: fdm=marker -- GitLab