From 661f58514353a360d18a4548d21331a7f3bab568 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Wed, 1 Feb 2017 23:10:43 -0600
Subject: [PATCH] CSE: Don't pull things out of derivatives, Subs (closes #14).

---
 sumpy/codegen.py | 22 ----------------------
 sumpy/cse.py     | 25 ++++++++++++++++++-------
 test/test_cse.py | 30 +++++++++++++++++-------------
 3 files changed, 35 insertions(+), 42 deletions(-)

diff --git a/sumpy/codegen.py b/sumpy/codegen.py
index 01385642..8eef9892 100644
--- a/sumpy/codegen.py
+++ b/sumpy/codegen.py
@@ -63,24 +63,6 @@ _SPECIAL_FUNCTION_NAMES = frozenset(dir(sym.functions))
 
 
 class SympyToPymbolicMapper(SympyToPymbolicMapperBase):
-    def __init__(self, assignments):
-        self.assignments = dict(
-            (sym.Symbol(name), value) for name, value in assignments)
-        self.derivative_cse_names = set()
-
-    def map_Derivative(self, expr):  # noqa
-        # Sympy has picked up the habit of picking arguments out of derivatives
-        # and pronounce them common subexpressions. Me no like. Undo it, so
-        # that the bessel substitutor further down can do its job.
-
-        if expr.expr.is_Symbol:
-            # These will get removed, because loopy wont' be able to deal
-            # with them--they contain undefined placeholder symbols.
-            self.derivative_cse_names.add(expr.expr.name)
-
-        return prim.Derivative(self.rec(
-            expr.expr.subs(self.assignments)),
-            tuple(v.name for v in expr.variables))
 
     def not_supported(self, expr):
         if isinstance(expr, int):
@@ -676,10 +658,6 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[],
     # convert from sympy
     sympy_conv = SympyToPymbolicMapper(assignments)
     assignments = [(name, sympy_conv(expr)) for name, expr in assignments]
-    assignments = [
-            (name, expr) for name, expr in assignments
-            if name not in sympy_conv.derivative_cse_names
-            ]
 
     assignments = kill_trivial_assignments(assignments, retain_names)
 
diff --git a/sumpy/cse.py b/sumpy/cse.py
index a073983a..6a0270e7 100644
--- a/sumpy/cse.py
+++ b/sumpy/cse.py
@@ -66,7 +66,8 @@ DAMAGE.
 
 # }}}
 
-from sumpy.symbolic import Basic, Mul, Add, Pow, Symbol, _coeff_isneg
+from sumpy.symbolic import (
+    Basic, Mul, Add, Pow, Symbol, _coeff_isneg, Derivative, Subs)
 from sympy.core.compatibility import iterable
 from sympy.utilities.iterables import numbered_symbols
 
@@ -81,6 +82,10 @@ Common subexpression elimination
 """
 
 
+# Don't CSE child nodes of these classes.
+CSE_NO_DESCEND_CLASSES = (Derivative, Subs)
+
+
 # {{{ cse pre/postprocessing
 
 def preprocess_for_cse(expr, optimizations):
@@ -352,6 +357,9 @@ def opt_cse(exprs):
         if expr.is_Atom:
             return
 
+        if isinstance(expr, CSE_NO_DESCEND_CLASSES):
+            return
+
         if iterable(expr):
             for item in expr:
                 find_opts(item)
@@ -443,7 +451,10 @@ def tree_cse(exprs, symbols, opt_subs=None):
             if expr in opt_subs:
                 expr = opt_subs[expr]
 
-            args = expr.args
+            if isinstance(expr, CSE_NO_DESCEND_CLASSES):
+                args = ()
+            else:
+                args = expr.args
 
         for arg in args:
             find_repeated(arg)
@@ -481,11 +492,11 @@ def tree_cse(exprs, symbols, opt_subs=None):
         if expr in opt_subs:
             expr = opt_subs[expr]
 
-        new_args = tuple(rebuild(arg) for arg in expr.args)
-        if isinstance(expr, Unevaluated) or new_args != expr.args:
-            new_expr = expr.func(*new_args)
-        else:
-            new_expr = expr
+        new_expr = expr
+        if not isinstance(expr, CSE_NO_DESCEND_CLASSES):
+            new_args = tuple(rebuild(arg) for arg in expr.args)
+            if isinstance(expr, Unevaluated) or new_args != expr.args:
+                new_expr = expr.func(*new_args)
 
         if orig_expr in to_eliminate:
             try:
diff --git a/test/test_cse.py b/test/test_cse.py
index c9f0264e..4f957aaa 100644
--- a/test/test_cse.py
+++ b/test/test_cse.py
@@ -189,7 +189,7 @@ def test_multiple_expressions():
     rsubsts, _ = cse(reversed(l))
     assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
     assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
-    assert reduced == [f(x1,x2), x1, x2]
+    assert reduced == [f(x1, x2), x1, x2]
     l = [w*y + w + x + y + z, w*x*y]
     assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
     assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
@@ -202,27 +202,31 @@ def test_issue_4203():
     assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])
 
 
-@sympyonly
-def test_dont_cse_tuples():
+def test_dont_cse_subs():
     from sumpy.symbolic import Subs
     f = Function("f")
     g = Function("g")
 
     name_val, (expr,) = cse(
-        Subs(f(x, y), (x, y), (0, 1))
-        + Subs(g(x, y), (x, y), (0, 1)))
+        Subs(f(x, y), (x, y), (0, x + y))
+        + Subs(g(x, y), (x, y), (0, x + y)))
 
     assert name_val == []
-    assert expr == (Subs(f(x, y), (x, y), (0, 1))
-            + Subs(g(x, y), (x, y), (0, 1)))
+    assert expr == Subs(f(x, y), (x, y), (0, x + y)) + \
+        Subs(g(x, y), (x, y), (0, x + y))
 
-    name_val, (expr,) = cse(
-        Subs(f(x, y), (x, y), (0, x + y))
-        + Subs(g(x, y), (x, y), (0, x + y)))
 
-    assert name_val == [(x0, x + y)]
-    assert expr == Subs(f(x, y), (x, y), (0, x0)) + \
-        Subs(g(x, y), (x, y), (0, x0))
+def test_dont_cse_derivative():
+    from sumpy.symbolic import Derivative
+    f = Function("f")
+
+    # FIXME
+    deriv = Derivative(f(x+y), (x,)) if USE_SYMENGINE else Derivative(f(x+y), x)
+
+    name_val, (expr,) = cse(x + y + deriv)
+
+    assert name_val == []
+    assert expr == x + y + deriv
 
 
 def test_pow_invpow():
-- 
GitLab