diff --git a/pymbolic/interop/sympy.py b/pymbolic/interop/sympy.py index b02d5e8b692040abb6e46cd8fecde95d4ad2d035..2ccb9fe4072b83c7978572ad31f1e636824ed0e3 100644 --- a/pymbolic/interop/sympy.py +++ b/pymbolic/interop/sympy.py @@ -28,6 +28,9 @@ THE SOFTWARE. from pymbolic.interop.common import ( SympyLikeToPymbolicMapper, PymbolicToSympyLikeMapper) +import pymbolic.primitives as prim +from functools import partial + import sympy @@ -60,6 +63,29 @@ class SympyToPymbolicMapper(SympyLikeToPymbolicMapper): def map_long(self, expr): return long(expr) # noqa + def map_Piecewise(self, expr): # noqa + # We only handle piecewises with 2 arguments! + assert len(expr.args) == 2 + # We only handle if/else cases + assert expr.args[0][1] == sympy.Not(expr.args[1][1]) + then = self.rec(expr.args[0][0]) + else_ = self.rec(expr.args[1][0]) + cond = self.rec(expr.args[0][1]) + return prim.If(cond, then, else_) + + def _comparison_operator(self, expr, operator=None): + left = self.rec(expr.args[0]) + right = self.rec(expr.args[1]) + return prim.Comparison(left, operator, right) + + map_Equality = partial(_comparison_operator, operator="==") + map_Unequality = partial(_comparison_operator, operator="!=") + map_GreaterThan = partial(_comparison_operator, operator=">=") + map_LessThan = partial(_comparison_operator, operator="<=") + map_StrictGreaterThan = partial(_comparison_operator, operator=">") + map_StrictLessThan = partial(_comparison_operator, operator="<") + + # }}} @@ -77,6 +103,30 @@ class PymbolicToSympyMapper(PymbolicToSympyLikeMapper): return self.sym.Derivative(self.rec(expr.child), *[self.sym.Symbol(v) for v in expr.variables]) + def map_if(self, expr): + cond = self.rec(expr.condition) + return self.sym.Piecewise((self.rec(expr.then), cond), + (self.rec(expr.else_), self.sym.Not(cond)) + ) + + def map_comparison(self, expr): + left = self.rec(expr.left) + right = self.rec(expr.right) + if expr.operator == "==": + return self.sym.Equality(left, right) + elif expr.operator == "!=": + return self.sym.Unequality(left, right) + elif expr.operator == "<": + return self.sym.StrictLessThan(left, right) + elif expr.operator == ">": + return self.sym.StrictGreaterThan(left, right) + elif expr.operator == "<=": + return self.sym.LessThan(left, right) + elif expr.operator == ">=": + return self.sym.GreaterThan(left, right) + else: + raise NotImplementedError("Unknown operator '%s'" % expr.operator) + # }}} diff --git a/test/test_sympy.py b/test/test_sympy.py index 5415cb6227b82978dbc774d1a50075d8eca2cc6f..e0429bb3eaf327725a1846b3450f4a41391530c7 100644 --- a/test/test_sympy.py +++ b/test/test_sympy.py @@ -114,6 +114,17 @@ def test_pymbolic_to_sympy(): _test_from_pymbolic(mapper, sym, False) +def test_sympy_if_condition(): + pytest.importorskip("sympy") + from pymbolic.interop.sympy import PymbolicToSympyMapper, SympyToPymbolicMapper + forward = PymbolicToSympyMapper() + backward = SympyToPymbolicMapper() + + # Test round trip to sympy and back + expr = prim.If(prim.Comparison(x_, "<=", y_), 1, 0) + assert backward(forward(expr)) == expr + + if __name__ == "__main__": import sys if len(sys.argv) > 1: