From 661f58514353a360d18a4548d21331a7f3bab568 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Wed, 1 Feb 2017 23:10:43 -0600 Subject: [PATCH] CSE: Don't pull things out of derivatives, Subs (closes #14). --- sumpy/codegen.py | 22 ---------------------- sumpy/cse.py | 25 ++++++++++++++++++------- test/test_cse.py | 30 +++++++++++++++++------------- 3 files changed, 35 insertions(+), 42 deletions(-) diff --git a/sumpy/codegen.py b/sumpy/codegen.py index 01385642..8eef9892 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -63,24 +63,6 @@ _SPECIAL_FUNCTION_NAMES = frozenset(dir(sym.functions)) class SympyToPymbolicMapper(SympyToPymbolicMapperBase): - def __init__(self, assignments): - self.assignments = dict( - (sym.Symbol(name), value) for name, value in assignments) - self.derivative_cse_names = set() - - def map_Derivative(self, expr): # noqa - # Sympy has picked up the habit of picking arguments out of derivatives - # and pronounce them common subexpressions. Me no like. Undo it, so - # that the bessel substitutor further down can do its job. - - if expr.expr.is_Symbol: - # These will get removed, because loopy wont' be able to deal - # with them--they contain undefined placeholder symbols. - self.derivative_cse_names.add(expr.expr.name) - - return prim.Derivative(self.rec( - expr.expr.subs(self.assignments)), - tuple(v.name for v in expr.variables)) def not_supported(self, expr): if isinstance(expr, int): @@ -676,10 +658,6 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], # convert from sympy sympy_conv = SympyToPymbolicMapper(assignments) assignments = [(name, sympy_conv(expr)) for name, expr in assignments] - assignments = [ - (name, expr) for name, expr in assignments - if name not in sympy_conv.derivative_cse_names - ] assignments = kill_trivial_assignments(assignments, retain_names) diff --git a/sumpy/cse.py b/sumpy/cse.py index a073983a..6a0270e7 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -66,7 +66,8 @@ DAMAGE. # }}} -from sumpy.symbolic import Basic, Mul, Add, Pow, Symbol, _coeff_isneg +from sumpy.symbolic import ( + Basic, Mul, Add, Pow, Symbol, _coeff_isneg, Derivative, Subs) from sympy.core.compatibility import iterable from sympy.utilities.iterables import numbered_symbols @@ -81,6 +82,10 @@ Common subexpression elimination """ +# Don't CSE child nodes of these classes. +CSE_NO_DESCEND_CLASSES = (Derivative, Subs) + + # {{{ cse pre/postprocessing def preprocess_for_cse(expr, optimizations): @@ -352,6 +357,9 @@ def opt_cse(exprs): if expr.is_Atom: return + if isinstance(expr, CSE_NO_DESCEND_CLASSES): + return + if iterable(expr): for item in expr: find_opts(item) @@ -443,7 +451,10 @@ def tree_cse(exprs, symbols, opt_subs=None): if expr in opt_subs: expr = opt_subs[expr] - args = expr.args + if isinstance(expr, CSE_NO_DESCEND_CLASSES): + args = () + else: + args = expr.args for arg in args: find_repeated(arg) @@ -481,11 +492,11 @@ def tree_cse(exprs, symbols, opt_subs=None): if expr in opt_subs: expr = opt_subs[expr] - new_args = tuple(rebuild(arg) for arg in expr.args) - if isinstance(expr, Unevaluated) or new_args != expr.args: - new_expr = expr.func(*new_args) - else: - new_expr = expr + new_expr = expr + if not isinstance(expr, CSE_NO_DESCEND_CLASSES): + new_args = tuple(rebuild(arg) for arg in expr.args) + if isinstance(expr, Unevaluated) or new_args != expr.args: + new_expr = expr.func(*new_args) if orig_expr in to_eliminate: try: diff --git a/test/test_cse.py b/test/test_cse.py index c9f0264e..4f957aaa 100644 --- a/test/test_cse.py +++ b/test/test_cse.py @@ -189,7 +189,7 @@ def test_multiple_expressions(): rsubsts, _ = cse(reversed(l)) assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] - assert reduced == [f(x1,x2), x1, x2] + assert reduced == [f(x1, x2), x1, x2] l = [w*y + w + x + y + z, w*x*y] assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) @@ -202,27 +202,31 @@ def test_issue_4203(): assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) -@sympyonly -def test_dont_cse_tuples(): +def test_dont_cse_subs(): from sumpy.symbolic import Subs f = Function("f") g = Function("g") name_val, (expr,) = cse( - Subs(f(x, y), (x, y), (0, 1)) - + Subs(g(x, y), (x, y), (0, 1))) + Subs(f(x, y), (x, y), (0, x + y)) + + Subs(g(x, y), (x, y), (0, x + y))) assert name_val == [] - assert expr == (Subs(f(x, y), (x, y), (0, 1)) - + Subs(g(x, y), (x, y), (0, 1))) + assert expr == Subs(f(x, y), (x, y), (0, x + y)) + \ + Subs(g(x, y), (x, y), (0, x + y)) - name_val, (expr,) = cse( - Subs(f(x, y), (x, y), (0, x + y)) - + Subs(g(x, y), (x, y), (0, x + y))) - assert name_val == [(x0, x + y)] - assert expr == Subs(f(x, y), (x, y), (0, x0)) + \ - Subs(g(x, y), (x, y), (0, x0)) +def test_dont_cse_derivative(): + from sumpy.symbolic import Derivative + f = Function("f") + + # FIXME + deriv = Derivative(f(x+y), (x,)) if USE_SYMENGINE else Derivative(f(x+y), x) + + name_val, (expr,) = cse(x + y + deriv) + + assert name_val == [] + assert expr == x + y + deriv def test_pow_invpow(): -- GitLab