From 50221794054d5d42fe42f7e907cc772e07d68ac8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Fri, 7 Jun 2013 19:44:01 -0400 Subject: [PATCH] Test, improve sympy round-trip translation. --- pymbolic/sympy_interface.py | 26 +++++++++++ test/test_pymbolic.py | 87 +++++++++++++++++++++++++++---------- 2 files changed, 89 insertions(+), 24 deletions(-) diff --git a/pymbolic/sympy_interface.py b/pymbolic/sympy_interface.py index f66cdef..7df6149 100644 --- a/pymbolic/sympy_interface.py +++ b/pymbolic/sympy_interface.py @@ -66,6 +66,8 @@ def make_cse(arg, prefix=None): return result +# {{{ sympy -> pymbolic + class SympyToPymbolicMapper(SympyMapper): def map_Symbol(self, expr): return prim.Variable(expr.name) @@ -73,6 +75,9 @@ class SympyToPymbolicMapper(SympyMapper): def map_ImaginaryUnit(self, expr): return 1j + def map_Float(self, expr): + return float(expr) + def map_Pi(self, expr): return float(expr) @@ -116,11 +121,18 @@ class SympyToPymbolicMapper(SympyMapper): else: return SympyMapper.not_supported(self, expr) +# }}} + + +# {{{ pymbolic -> sympy class PymbolicToSympyMapper(EvaluationMapper): def map_variable(self, expr): return sp.Symbol(expr.name) + def map_constant(self, expr): + return sp.sympify(expr) + def map_call(self, expr): if isinstance(expr.function, prim.Variable): func_name = expr.function.name @@ -139,3 +151,17 @@ class PymbolicToSympyMapper(EvaluationMapper): else: raise RuntimeError("do not know how to translate '%s' to sympy" % expr) + + def map_substitution(self, expr): + return sp.Subs(self.rec(expr.child), + tuple(sp.Symbol(v) for v in expr.variables), + tuple(self.rec(v) for v in expr.values), + ) + + def map_derivative(self, expr): + return sp.Derivative(self.rec(expr.child), + *[sp.Symbol(v) for v in expr.variables]) + +# }}} + +# vim: fdm=marker diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 33173ba..3e7a1aa 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -43,6 +43,60 @@ def test_substitute(): assert evaluate(substitute(u, {xmin: 25})) == 630 +def test_no_comparison(): + from pymbolic import parse + + x = parse("17+3*x") + y = parse("12-5*y") + + def expect_typeerror(f): + try: + f() + except TypeError: + pass + else: + assert False + + expect_typeerror(lambda: x < y) + expect_typeerror(lambda: x <= y) + expect_typeerror(lambda: x > y) + expect_typeerror(lambda: x >= y) + + +def test_structure_preservation(): + x = prim.Sum((5, 7)) + from pymbolic.mapper import IdentityMapper + x2 = IdentityMapper()(x) + assert x == x2 + + +def test_sympy_interaction(): + pytest.importorskip("sympy") + + import sympy as sp + + x, y = sp.symbols("x y") + f = sp.symbols("f") + + s1_expr = 1/f(x/sp.sqrt(x**2+y**2)).diff(x, 5) + + from pymbolic.sympy_interface import ( + SympyToPymbolicMapper, + PymbolicToSympyMapper) + s2p = SympyToPymbolicMapper() + p2s = PymbolicToSympyMapper() + + p1_expr = s2p(s1_expr) + + s2_expr = p2s(p1_expr) + assert s1_expr == s2_expr + + p2_expr = s2p(s2_expr) + assert p1_expr == p2_expr + + +# {{{ fft + def test_fft_with_floats(): numpy = pytest.importorskip("numpy") import numpy.linalg as la @@ -97,6 +151,8 @@ def test_fft(): for i, line in enumerate(code): print("result[%d] = %s" % (i, line)) +# }}} + def test_sparse_multiply(): numpy = pytest.importorskip("numpy") @@ -117,25 +173,7 @@ def test_sparse_multiply(): assert la.norm(mat_vec-mat_vec_2) < 1e-14 -def test_no_comparison(): - from pymbolic import parse - - x = parse("17+3*x") - y = parse("12-5*y") - - def expect_typeerror(f): - try: - f() - except TypeError: - pass - else: - assert False - - expect_typeerror(lambda: x < y) - expect_typeerror(lambda: x <= y) - expect_typeerror(lambda: x > y) - expect_typeerror(lambda: x >= y) - +# {{{ parser def test_parser(): from pymbolic import parse @@ -172,13 +210,10 @@ def test_parser(): assert parse("f((x,),z)") == f((x,), z) assert parse("f(x,(y,z),z)") == f(x, (y, z), z) +# }}} -def test_structure_preservation(): - x = prim.Sum((5, 7)) - from pymbolic.mapper import IdentityMapper - x2 = IdentityMapper()(x) - assert x == x2 +# {{{ geometric algebra @pytest.mark.parametrize("dims", [2, 3, 4, 5]) # START_GA_TEST @@ -279,6 +314,8 @@ def test_geometric_algebra(dims): assert a.x(b*c) .close_to(a.x(b)*c + b*a.x(c)) # END_GA_TEST +# }}} + if __name__ == "__main__": import sys @@ -287,3 +324,5 @@ if __name__ == "__main__": else: from py.test.cmdline import main main([__file__]) + +# vim: fdm=marker -- GitLab