diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py index 3b871ff00a2d79e6f791300784a2ab0459430279..147a8876dbc2ebae463c597152700df7cacd7c3c 100644 --- a/pymbolic/interop/ast.py +++ b/pymbolic/interop/ast.py @@ -103,7 +103,7 @@ def _mult(x, y): def _neg(x): - return p.Product((-1, x),) + return -x class ASTToPymbolic(ASTMapper): @@ -144,14 +144,14 @@ class ASTToPymbolic(ASTMapper): def map_UnaryOp(self, expr): # noqa try: - op_constructor = self.unary_op_map[expr.op] + op_constructor = self.unary_op_map[type(expr.op)] except KeyError: raise NotImplementedError( "%s does not know how to map operator '%s'" % (type(self).__name__, type(expr.op).__name__)) - return op_constructor(self.rec(expr.left), self.rec(expr.right)) + return op_constructor(self.rec(expr.operand)) def map_IfExp(self, expr): # noqa # (expr test, expr body, expr orelse) diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 7c680ef002616092616caed6029476e75ad06933..ce6f0b8eacbc55b06ac77d5e43b6f3cedf2b9cc7 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -69,6 +69,7 @@ PREC_BITWISE_OR = 7 PREC_COMPARISON = 6 PREC_LOGICAL_AND = 5 PREC_LOGICAL_OR = 4 +PREC_IF = 3 PREC_NONE = 0 @@ -351,16 +352,20 @@ class StringifyMapper(pymbolic.mapper.Mapper): type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs)) def map_if(self, expr, enclosing_prec, *args, **kwargs): - return "If(%s, %s, %s)" % ( - self.rec(expr.condition, PREC_NONE, *args, **kwargs), - self.rec(expr.then, PREC_NONE, *args, **kwargs), - self.rec(expr.else_, PREC_NONE, *args, **kwargs)) + return self.parenthesize_if_needed( + "%s if %s else %s" % ( + self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.condition, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)), + enclosing_prec, PREC_IF) def map_if_positive(self, expr, enclosing_prec, *args, **kwargs): - return "If(%s > 0, %s, %s)" % ( - self.rec(expr.criterion, PREC_NONE, *args, **kwargs), - self.rec(expr.then, PREC_NONE, *args, **kwargs), - self.rec(expr.else_, PREC_NONE, *args, **kwargs)) + return self.parenthesize_if_needed( + "%s if %s > 0 else %s" % ( + self.rec(expr.then, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.criterion, PREC_LOGICAL_OR, *args, **kwargs), + self.rec(expr.else_, PREC_LOGICAL_OR, *args, **kwargs)), + enclosing_prec, PREC_IF) def map_min(self, expr, enclosing_prec, *args, **kwargs): what = type(expr).__name__.lower() diff --git a/pymbolic/parser.py b/pymbolic/parser.py index 9896b524db0793cbed7813d7eeacceebe0d0456a..d0e49cd9b1f967d3c049d000d53f9f6739091a21 100644 --- a/pymbolic/parser.py +++ b/pymbolic/parser.py @@ -60,6 +60,8 @@ _rightshift = intern("rightshift") _and = intern("and") _or = intern("or") _not = intern("not") +_if = intern("if") +_else = intern("else") _bitwiseand = intern("bitwiseand") _bitwiseor = intern("bitwiseor") @@ -68,6 +70,7 @@ _bitwisenot = intern("bitwisenot") _PREC_COMMA = 5 # must be > 1 (1 is used by fortran-to-cl) _PREC_SLICE = 10 +_PREC_IF = 75 _PREC_LOGICAL_OR = 80 _PREC_LOGICAL_AND = 90 @@ -126,6 +129,8 @@ class Parser(object): (_and, pytools.lex.RE(r"and\b")), (_or, pytools.lex.RE(r"or\b")), (_not, pytools.lex.RE(r"not\b")), + (_if, pytools.lex.RE(r"if\b")), + (_else, pytools.lex.RE(r"else\b")), (_imaginary, (_float, pytools.lex.RE("j"))), (_float, ("|", @@ -316,6 +321,17 @@ class Parser(object): pstate.expect(_closebracket) pstate.advance() did_something = True + elif next_tag is _if and _PREC_IF > min_precedence: + from pymbolic.primitives import If + then_expr = left_exp + pstate.advance() + pstate.expect_not_end() + condition = self.parse_expression(pstate, _PREC_LOGICAL_OR) + pstate.expect(_else) + pstate.advance() + else_expr = self.parse_expression(pstate) + left_exp = If(condition, then_expr, else_expr) + did_something = True elif next_tag is _dot and _PREC_CALL > min_precedence: pstate.advance() pstate.expect(_identifier) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 687fb78984f5411257cff40d6ac6b56e08b76604..bc8b6edaea329a18e4a87de95ec4a0c0fe67d9e6 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -25,6 +25,7 @@ THE SOFTWARE. import pymbolic.primitives as prim import pytest from pymbolic import parse +from pytools.lex import ParseError from pymbolic.mapper import IdentityMapper @@ -227,6 +228,21 @@ def test_parser(): strified = StringifyMapper()(expr) assert strified == expr_str, (strified, expr_str) + def assert_parsed_same_as_python(expr_str): + # makes sure that has only one line + expr_str, = expr_str.split('\n') + from pymbolic.interop.ast import ASTToPymbolic + import ast + ast2p = ASTToPymbolic() + try: + expr_parsed_by_python = ast2p(ast.parse(expr_str).body[0].value) + except SyntaxError: + with pytest.raises(ParseError): + parse(expr_str) + else: + expr_parsed_by_pymbolic = parse(expr_str) + assert expr_parsed_by_python == expr_parsed_by_pymbolic + assert_parse_roundtrip("()") assert_parse_roundtrip("(3,)") @@ -264,6 +280,10 @@ def test_parser(): assert parse("f(x,(y,z),z, name=15, name2=17)") == f( x, (y, z), z, name=15, name2=17) + assert_parsed_same_as_python('5+i if i>=0 else (0 if i<-1 else 10)') + assert_parsed_same_as_python("0 if 1 if 2 else 3 else 4") + assert_parsed_same_as_python("0 if (1 if 2 else 3) else 4") + # }}}