diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 4ff3adbc0203d4338c885d2cd0cfa79da0caf722..f99c01c6252e082f7935f4e55fe47285caed02e0 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -130,6 +130,18 @@ class CombineMapper(RecursiveMapper): self.rec(coeff, *args) for exp, coeff in expr.data) ) + def map_logical_and(self, expr, *args): + return self.combine(self.rec(child, *args) + for child in expr.children) + + map_logical_or = map_logical_and + map_logical_not = map_negation + + def map_comparison(self, expr, *args): + return self.combine(( + self.rec(expr.left, *args), + self.rec(expr.right, *args))) + def map_list(self, expr, *args): return self.combine(self.rec(child, *args) for child in expr) @@ -147,6 +159,11 @@ class CombineMapper(RecursiveMapper): self.rec(expr.then), self.rec(expr.else_)]) + def map_if(self, expr): + return self.combine([ + self.rec(expr.condition), + self.rec(expr.then), + self.rec(expr.else_)]) # }}} @@ -209,6 +226,23 @@ class IdentityMapperBase(object): ((exp, self.rec(coeff, *args)) for exp, coeff in expr.data)) + def map_logical_and(self, expr, *args): + return type(expr)(tuple( + self.rec(child, *args) for child in expr.children)) + + map_logical_or = map_logical_and + + def map_logical_not(self, expr, *args): + from pymbolic.primitives import LogicalNot + return LogicalNot( + self.rec(expr.child, *args)) + + def map_comparison(self, expr, *args): + return type(expr)( + self.rec(expr.left, *args), + expr.operator, + self.rec(expr.right, *args)) + def map_list(self, expr, *args): return [self.rec(child, *args) for child in expr] @@ -235,12 +269,16 @@ class IdentityMapperBase(object): **expr.get_extra_properties()) def map_if_positive(self, expr): - return expr.__class__( + return type(expr)( self.rec(expr.criterion), self.rec(expr.then), - self.rec(expr.else_), - ) + self.rec(expr.else_)) + def map_if(self, expr): + return type(expr)( + self.rec(expr.condition), + self.rec(expr.then), + self.rec(expr.else_)) diff --git a/pymbolic/mapper/c_code.py b/pymbolic/mapper/c_code.py index 0c06625804154d83db0acd505a9ee4d978a2fd18..5be35a6fc6e45270a01c8811df6b3f68395a2d9e 100644 --- a/pymbolic/mapper/c_code.py +++ b/pymbolic/mapper/c_code.py @@ -123,3 +123,11 @@ class CCodeMapper(SimplifyingSortingStringifyMapper): self.rec(expr.else_, PREC_NONE), ) + def map_if(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_NONE + return self.format("(%s ? %s : %s)", + self.rec(expr.condition, PREC_NONE), + self.rec(expr.then, PREC_NONE), + self.rec(expr.else_, PREC_NONE), + ) + diff --git a/pymbolic/parser.py b/pymbolic/parser.py index bd53de130775dcd74b44ea72d7b5653f00687998..751cd14301d7bd46d7b94d0c342a5f6256c266c8 100644 --- a/pymbolic/parser.py +++ b/pymbolic/parser.py @@ -17,6 +17,17 @@ _whitespace = intern("whitespace") _comma = intern("comma") _dot = intern("dot") +_equal = intern("equal") +_notequal = intern("notequal") +_less = intern("less") +_lessequal = intern("lessequal") +_greater = intern("greater") +_greaterequal = intern("greaterequal") + +_and = intern("and") +_or = intern("or") +_not = intern("not") + _PREC_COMMA = 5 # must be > 1 (1 is used by fortran-to-cl) _PREC_LOGICAL_OR = 80 _PREC_LOGICAL_AND = 90 @@ -32,6 +43,17 @@ _PREC_CALL = 150 class Parser: lex_table = [ + (_equal, pytools.lex.RE(r"==")), + (_notequal, pytools.lex.RE(r"!=")), + (_less, pytools.lex.RE(r"\<")), + (_lessequal, pytools.lex.RE(r"\<=")), + (_greater, pytools.lex.RE(r"\>")), + (_greaterequal, pytools.lex.RE(r"\>=")), + + (_and, pytools.lex.RE(r"and")), + (_or, pytools.lex.RE(r"or")), + (_not, pytools.lex.RE(r"not")), + (_imaginary, (_float, pytools.lex.RE("j"))), (_float, ("|", pytools.lex.RE(r"[+-]?[0-9]+\.[0-9]*([eEdD][+-]?[0-9]+)?"), @@ -54,6 +76,15 @@ class Parser: (_dot, pytools.lex.RE(r"\.")), ] + _COMP_TABLE = { + _greater: ">", + _greaterequal: ">=", + _less: "<", + _lessequal: "<=", + _equal: "==", + _notequal: "!=", + } + def parse_terminal(self, pstate): import pymbolic.primitives as primitives @@ -83,6 +114,11 @@ class Parser: elif pstate.is_next(_minus): pstate.advance() left_exp = -self.parse_expression(pstate, _PREC_UNARY) + elif pstate.is_next(_not): + pstate.advance() + from pymbolic.primitives import LogicalNot + left_exp = LogicalNot( + self.parse_expression(pstate, _PREC_UNARY)) elif pstate.is_next(_openpar): pstate.advance() left_exp = self.parse_expression(pstate) @@ -102,8 +138,9 @@ class Parser: if pstate.is_at_end(): return left_exp - left_exp, did_something = self.parse_postfix( + result = self.parse_postfix( pstate, min_precedence, left_exp) + left_exp, did_something = result return left_exp @@ -161,18 +198,41 @@ class Parser: pstate.advance() left_exp **= self.parse_expression(pstate, _PREC_POWER) did_something = True + elif next_tag is _and and _PREC_LOGICAL_AND > min_precedence: + pstate.advance() + from pymbolic.primitives import LogicalAnd + left_exp = LogicalAnd(( + left_exp, + self.parse_expression(pstate, _PREC_LOGICAL_AND))) + did_something = True + elif next_tag is _or and _PREC_LOGICAL_OR > min_precedence: + pstate.advance() + from pymbolic.primitives import LogicalOr + left_exp = LogicalOr(( + left_exp, + self.parse_expression(pstate, _PREC_LOGICAL_OR))) + did_something = True + elif next_tag in self._COMP_TABLE and _PREC_COMPARISON > min_precedence: + pstate.advance() + from pymbolic.primitives import ComparisonOperator + left_exp = ComparisonOperator( + left_exp, + self._COMP_TABLE[next_tag], + self.parse_expression(pstate, _PREC_COMPARISON)) + did_something = True elif next_tag is _comma and _PREC_COMMA > min_precedence: # The precedence makes the comma left-associative. pstate.advance() if pstate.is_at_end() or pstate.next_tag() is _closepar: - return (left_exp,) - - new_el = self.parse_expression(pstate, _PREC_COMMA) - if isinstance(left_exp, tuple): - left_exp = left_exp + (new_el,) + left_exp = (left_exp,) else: - left_exp = (left_exp, new_el) + new_el = self.parse_expression(pstate, _PREC_COMMA) + if isinstance(left_exp, tuple): + left_exp = left_exp + (new_el,) + else: + left_exp = (left_exp, new_el) + did_something = True return left_exp, did_something diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index b4604294c6ebdc841891a0c567b198144d4f5913..09c51a78d205822d40e37cb54caeb832f1763eaf 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -503,7 +503,7 @@ class ComparisonOperator(Expression): def __init__(self, left, operator, right): self.left = left self.right = right - if not operator in [">", ">=", "==", "<", "<="]: + if not operator in [">", ">=", "==", "!=", "<", "<="]: raise RuntimeError("invalid operator") self.operator = operator @@ -558,8 +558,26 @@ class LogicalAnd(BooleanExpression): +class If(Expression): + def __init__(self, criterion, then, else_): + self.condition = criterion + self.then = then + self.else_ = else_ + + def __getinitargs__(self): + return self.condition, self.then, self.else_ + + mapper_method = intern("map_if") + + + + class IfPositive(Expression): def __init__(self, criterion, then, else_): + from warnings import warn + warn("IfPositive is deprecated, use If( ... >0)", DeprecationWarning, + stacklevel=2) + self.criterion = criterion self.then = then self.else_ = else_ @@ -672,12 +690,12 @@ class CommonSubexpression(Expression): mapper_method = intern("map_common_subexpression") +# }}} - -# intelligent makers --------------------------------------------------------- +# intelligent factory functions ---------------------------------------------- def make_variable(var_or_string): if not isinstance(var_or_string, Expression): return Variable(var_or_string)