From 50221794054d5d42fe42f7e907cc772e07d68ac8 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 7 Jun 2013 19:44:01 -0400
Subject: [PATCH] Test, improve sympy round-trip translation.

---
 pymbolic/sympy_interface.py | 26 +++++++++++
 test/test_pymbolic.py       | 87 +++++++++++++++++++++++++++----------
 2 files changed, 89 insertions(+), 24 deletions(-)

diff --git a/pymbolic/sympy_interface.py b/pymbolic/sympy_interface.py
index f66cdef..7df6149 100644
--- a/pymbolic/sympy_interface.py
+++ b/pymbolic/sympy_interface.py
@@ -66,6 +66,8 @@ def make_cse(arg, prefix=None):
     return result
 
 
+# {{{ sympy -> pymbolic
+
 class SympyToPymbolicMapper(SympyMapper):
     def map_Symbol(self, expr):
         return prim.Variable(expr.name)
@@ -73,6 +75,9 @@ class SympyToPymbolicMapper(SympyMapper):
     def map_ImaginaryUnit(self, expr):
         return 1j
 
+    def map_Float(self, expr):
+        return float(expr)
+
     def map_Pi(self, expr):
         return float(expr)
 
@@ -116,11 +121,18 @@ class SympyToPymbolicMapper(SympyMapper):
         else:
             return SympyMapper.not_supported(self, expr)
 
+# }}}
+
+
+# {{{ pymbolic -> sympy
 
 class PymbolicToSympyMapper(EvaluationMapper):
     def map_variable(self, expr):
         return sp.Symbol(expr.name)
 
+    def map_constant(self, expr):
+        return sp.sympify(expr)
+
     def map_call(self, expr):
         if isinstance(expr.function, prim.Variable):
             func_name = expr.function.name
@@ -139,3 +151,17 @@ class PymbolicToSympyMapper(EvaluationMapper):
         else:
             raise RuntimeError("do not know how to translate '%s' to sympy"
                     % expr)
+
+    def map_substitution(self, expr):
+        return sp.Subs(self.rec(expr.child),
+                tuple(sp.Symbol(v) for v in expr.variables),
+                tuple(self.rec(v) for v in expr.values),
+                )
+
+    def map_derivative(self, expr):
+        return sp.Derivative(self.rec(expr.child),
+                *[sp.Symbol(v) for v in expr.variables])
+
+# }}}
+
+# vim: fdm=marker
diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py
index 33173ba..3e7a1aa 100644
--- a/test/test_pymbolic.py
+++ b/test/test_pymbolic.py
@@ -43,6 +43,60 @@ def test_substitute():
     assert evaluate(substitute(u, {xmin: 25})) == 630
 
 
+def test_no_comparison():
+    from pymbolic import parse
+
+    x = parse("17+3*x")
+    y = parse("12-5*y")
+
+    def expect_typeerror(f):
+        try:
+            f()
+        except TypeError:
+            pass
+        else:
+            assert False
+
+    expect_typeerror(lambda: x < y)
+    expect_typeerror(lambda: x <= y)
+    expect_typeerror(lambda: x > y)
+    expect_typeerror(lambda: x >= y)
+
+
+def test_structure_preservation():
+    x = prim.Sum((5, 7))
+    from pymbolic.mapper import IdentityMapper
+    x2 = IdentityMapper()(x)
+    assert x == x2
+
+
+def test_sympy_interaction():
+    pytest.importorskip("sympy")
+
+    import sympy as sp
+
+    x, y = sp.symbols("x y")
+    f = sp.symbols("f")
+
+    s1_expr = 1/f(x/sp.sqrt(x**2+y**2)).diff(x, 5)
+
+    from pymbolic.sympy_interface import (
+            SympyToPymbolicMapper,
+            PymbolicToSympyMapper)
+    s2p = SympyToPymbolicMapper()
+    p2s = PymbolicToSympyMapper()
+
+    p1_expr = s2p(s1_expr)
+
+    s2_expr = p2s(p1_expr)
+    assert s1_expr == s2_expr
+
+    p2_expr = s2p(s2_expr)
+    assert p1_expr == p2_expr
+
+
+# {{{ fft
+
 def test_fft_with_floats():
     numpy = pytest.importorskip("numpy")
     import numpy.linalg as la
@@ -97,6 +151,8 @@ def test_fft():
     for i, line in enumerate(code):
         print("result[%d] = %s" % (i, line))
 
+# }}}
+
 
 def test_sparse_multiply():
     numpy = pytest.importorskip("numpy")
@@ -117,25 +173,7 @@ def test_sparse_multiply():
     assert la.norm(mat_vec-mat_vec_2) < 1e-14
 
 
-def test_no_comparison():
-    from pymbolic import parse
-
-    x = parse("17+3*x")
-    y = parse("12-5*y")
-
-    def expect_typeerror(f):
-        try:
-            f()
-        except TypeError:
-            pass
-        else:
-            assert False
-
-    expect_typeerror(lambda: x < y)
-    expect_typeerror(lambda: x <= y)
-    expect_typeerror(lambda: x > y)
-    expect_typeerror(lambda: x >= y)
-
+# {{{ parser
 
 def test_parser():
     from pymbolic import parse
@@ -172,13 +210,10 @@ def test_parser():
     assert parse("f((x,),z)") == f((x,), z)
     assert parse("f(x,(y,z),z)") == f(x, (y, z), z)
 
+# }}}
 
-def test_structure_preservation():
-    x = prim.Sum((5, 7))
-    from pymbolic.mapper import IdentityMapper
-    x2 = IdentityMapper()(x)
-    assert x == x2
 
+# {{{ geometric algebra
 
 @pytest.mark.parametrize("dims", [2, 3, 4, 5])
 # START_GA_TEST
@@ -279,6 +314,8 @@ def test_geometric_algebra(dims):
         assert a.x(b*c) .close_to(a.x(b)*c + b*a.x(c))
 # END_GA_TEST
 
+# }}}
+
 
 if __name__ == "__main__":
     import sys
@@ -287,3 +324,5 @@ if __name__ == "__main__":
     else:
         from py.test.cmdline import main
         main([__file__])
+
+# vim: fdm=marker
-- 
GitLab