From b6d277cdb7e06e179da80b3d06df9749b5a3cfbe Mon Sep 17 00:00:00 2001 From: Dominic Kempf Date: Thu, 7 Sep 2017 21:39:16 +0200 Subject: [PATCH 1/2] Implement pymbolics If statement through sympys Piecewise Previous implementation of If through the EvaluationMapper resulted in wrong behaviour. Although the Piecewise node is no exact match for the if, it can be used to express it. --- pymbolic/interop/sympy.py | 50 +++++++++++++++++++++++++++++++++++++++ test/test_sympy.py | 10 ++++++++ 2 files changed, 60 insertions(+) diff --git a/pymbolic/interop/sympy.py b/pymbolic/interop/sympy.py index b02d5e8..c5e1e05 100644 --- a/pymbolic/interop/sympy.py +++ b/pymbolic/interop/sympy.py @@ -28,6 +28,9 @@ THE SOFTWARE. from pymbolic.interop.common import ( SympyLikeToPymbolicMapper, PymbolicToSympyLikeMapper) +import pymbolic.primitives as prim +from functools import partial + import sympy @@ -60,6 +63,29 @@ class SympyToPymbolicMapper(SympyLikeToPymbolicMapper): def map_long(self, expr): return long(expr) # noqa + def map_Piecewise(self, expr): # noqa + # We only handle piecewises with 2 arguments! + assert len(expr.args) == 2 + # We only handle if/else cases + assert expr.args[0][1] == sympy.Not(expr.args[1][1]) + then = self.rec(expr.args[0][0]) + else_ = self.rec(expr.args[1][0]) + cond = self.rec(expr.args[0][1]) + return prim.If(cond, then, else_) + + def _comparison_operator(self, expr, operator=None): + left = self.rec(expr.args[0]) + right = self.rec(expr.args[1]) + return prim.Comparison(left, operator, right) + + map_Equality = partial(_comparison_operator, operator="==") + map_Unequality = partial(_comparison_operator, operator="!=") + map_GreaterThan = partial(_comparison_operator, operator=">=") + map_LessThan = partial(_comparison_operator, operator="<=") + map_StrictGreaterThan = partial(_comparison_operator, operator=">") + map_StrictLessThan = partial(_comparison_operator, operator="<") + + # }}} @@ -77,6 +103,30 @@ class PymbolicToSympyMapper(PymbolicToSympyLikeMapper): return self.sym.Derivative(self.rec(expr.child), *[self.sym.Symbol(v) for v in expr.variables]) + def map_if(self, expr): + cond = self.rec(expr.condition) + return self.sym.Piecewise((self.rec(expr.then), cond), + (self.rec(expr.else_), self.sym.Not(cond)) + ) + + def map_comparison(self, expr): + left = self.rec(expr.left) + right = self.rec(expr.right) + if expr.operator == "==": + return self.sym.Equality(left, right) + elif expr.operator == "!=": + return self.sym.Unequality(left, right) + elif expr.operator == "<": + return self.sym.StrictLessThan(left, right) + elif expr.operator == ">": + return self.sym.StrictGreaterThan(left, right) + elif expr.operator == "<=": + return self.sym.LessThan(left, right) + elif expr.operator == ">=": + return self.sym.GreaterThan(left, right) + else: + raise NotImplementedError("Cannot understand operator {}".format(expr.operator)) + # }}} diff --git a/test/test_sympy.py b/test/test_sympy.py index 5415cb6..e5a00a8 100644 --- a/test/test_sympy.py +++ b/test/test_sympy.py @@ -114,6 +114,16 @@ def test_pymbolic_to_sympy(): _test_from_pymbolic(mapper, sym, False) +def test_sympy_if_condition(): + from pymbolic.interop.sympy import PymbolicToSympyMapper, SympyToPymbolicMapper + forward = PymbolicToSympyMapper() + backward = SympyToPymbolicMapper() + + # Test round trip to sympy and back + expr = prim.If(prim.Comparison(x_, "<=", y_), 1, 0) + assert backward(forward(expr)) == expr + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab From 7063e1109756eb1b662c29324c9bc2646e1122f7 Mon Sep 17 00:00:00 2001 From: Dominic Kempf Date: Thu, 7 Sep 2017 22:01:11 +0200 Subject: [PATCH 2/2] Fix tests for If -> Piecewise --- pymbolic/interop/sympy.py | 2 +- test/test_sympy.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pymbolic/interop/sympy.py b/pymbolic/interop/sympy.py index c5e1e05..2ccb9fe 100644 --- a/pymbolic/interop/sympy.py +++ b/pymbolic/interop/sympy.py @@ -125,7 +125,7 @@ class PymbolicToSympyMapper(PymbolicToSympyLikeMapper): elif expr.operator == ">=": return self.sym.GreaterThan(left, right) else: - raise NotImplementedError("Cannot understand operator {}".format(expr.operator)) + raise NotImplementedError("Unknown operator '%s'" % expr.operator) # }}} diff --git a/test/test_sympy.py b/test/test_sympy.py index e5a00a8..e0429bb 100644 --- a/test/test_sympy.py +++ b/test/test_sympy.py @@ -115,6 +115,7 @@ def test_pymbolic_to_sympy(): def test_sympy_if_condition(): + pytest.importorskip("sympy") from pymbolic.interop.sympy import PymbolicToSympyMapper, SympyToPymbolicMapper forward = PymbolicToSympyMapper() backward = SympyToPymbolicMapper() -- GitLab