From eebd50385ccc14bac2c63735889efaa6fbf6487c Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 21 Jun 2021 13:44:46 -0500 Subject: [PATCH] Fix piecewise conversion with symengine (#55) * Fix piecewise conversion with symengine * assert -> NotImplementedError and readability fix * Fix cond, else_ order --- pymbolic/interop/common.py | 10 ---------- pymbolic/interop/symengine.py | 11 +++++++++++ pymbolic/interop/sympy.py | 12 ++++++++++++ test/test_sympy.py | 4 +--- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pymbolic/interop/common.py b/pymbolic/interop/common.py index 4ea1e4b..19416d5 100644 --- a/pymbolic/interop/common.py +++ b/pymbolic/interop/common.py @@ -124,16 +124,6 @@ class SympyLikeToPymbolicMapper(SympyLikeMapperBase): else: return SympyLikeMapperBase.not_supported(self, expr) - def map_Piecewise(self, expr): # noqa - # We only handle piecewises with 2 arguments! - assert len(expr.args) == 2 - # We only handle if/else cases - assert expr.args[1][1].is_Boolean and bool(expr.args[1][1]) is True - then = self.rec(expr.args[0][0]) - else_ = self.rec(expr.args[1][0]) - cond = self.rec(expr.args[0][1]) - return prim.If(cond, then, else_) - def _comparison_operator(self, expr, operator=None): left = self.rec(expr.args[0]) right = self.rec(expr.args[1]) diff --git a/pymbolic/interop/symengine.py b/pymbolic/interop/symengine.py index 2fb7849..e5c7ceb 100644 --- a/pymbolic/interop/symengine.py +++ b/pymbolic/interop/symengine.py @@ -67,6 +67,17 @@ class SymEngineToPymbolicMapper(SympyLikeToPymbolicMapper): map_RealDouble = SympyLikeToPymbolicMapper.to_float # noqa: N815 + def map_Piecewise(self, expr): # noqa + # We only handle piecewises with 2 statements! + if not len(expr.args) == 4: + raise NotImplementedError + # We only handle if/else cases + if not (expr.args[3].is_Boolean and bool(expr.args[3]) is True): + raise NotImplementedError + rec_args = [self.rec(arg) for arg in expr.args[:3]] + then, cond, else_ = rec_args + return prim.If(cond, then, else_) + def function_name(self, expr): try: # For FunctionSymbol instances diff --git a/pymbolic/interop/sympy.py b/pymbolic/interop/sympy.py index dc857cc..9ff303d 100644 --- a/pymbolic/interop/sympy.py +++ b/pymbolic/interop/sympy.py @@ -76,6 +76,18 @@ class SympyToPymbolicMapper(SympyLikeToPymbolicMapper): return prim.CommonSubexpression( self.rec(expr.args[0]), expr.prefix, expr.scope) + def map_Piecewise(self, expr): # noqa + # We only handle piecewises with 2 arguments! + if not len(expr.args) == 2: + raise NotImplementedError + # We only handle if/else cases + if not (expr.args[1][1].is_Boolean and bool(expr.args[1][1]) is True): + raise NotImplementedError + then = self.rec(expr.args[0][0]) + else_ = self.rec(expr.args[1][0]) + cond = self.rec(expr.args[0][1]) + return prim.If(cond, then, else_) + # }}} diff --git a/test/test_sympy.py b/test/test_sympy.py index ddeba5a..3f45d31 100644 --- a/test/test_sympy.py +++ b/test/test_sympy.py @@ -134,11 +134,9 @@ def _test_roundtrip(forward, backward, sym, use_symengine): x_[0], x_[i_, j_], prim.Variable("f")(x_), + prim.If(prim.Comparison(x_, "<=", y_), 1, 0), ] - if not use_symengine: - exprs.append(prim.If(prim.Comparison(x_, "<=", y_), 1, 0)) - for expr in exprs: assert expr == backward(forward(expr)) -- GitLab