From db599e5bb2ecd8098fbf660c53cf35ca71a37bb1 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 21 Jan 2012 15:27:49 -0500
Subject: [PATCH] Modularize parser. Add support for logical and comparison
 operators.

---
 pymbolic/mapper/stringifier.py |  44 ++++--
 pymbolic/parser.py             | 246 ++++++++++++++++++---------------
 pymbolic/primitives.py         |   1 +
 3 files changed, 173 insertions(+), 118 deletions(-)

diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py
index b6c5c0b..e40bad7 100644
--- a/pymbolic/mapper/stringifier.py
+++ b/pymbolic/mapper/stringifier.py
@@ -3,11 +3,14 @@ import pymbolic.mapper
 
 
 
-PREC_CALL = 5
-PREC_POWER = 4
-PREC_UNARY = 3
-PREC_PRODUCT = 2
-PREC_SUM = 1
+PREC_CALL = 15
+PREC_POWER = 14
+PREC_UNARY = 13
+PREC_PRODUCT = 12
+PREC_SUM = 11
+PREC_COMPARISON = 10
+PREC_LOGICAL_AND = 9
+PREC_LOGICAL_OR = 8
 PREC_NONE = 0
 
 
@@ -92,17 +95,19 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper):
 
     def map_quotient(self, expr, enclosing_prec):
         return self.parenthesize_if_needed(
-                self.format("%s/%s",
+                self.format("%s / %s",
+                    # space is necessary--otherwise '/*' becomes
+                    # start-of-comment in C.
                     self.rec(expr.numerator, PREC_PRODUCT),
                     self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1}
                 enclosing_prec, PREC_PRODUCT)
 
     def map_floor_div(self, expr, enclosing_prec):
         return self.parenthesize_if_needed(
-                self.format("%s//%s",
+                self.format("%s // %s",
                     self.rec(expr.numerator, PREC_PRODUCT),
                     self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1}
-                enclosing_prec, PREC_SUM)
+                enclosing_prec, PREC_PRODUCT)
 
     def map_power(self, expr, enclosing_prec):
         return self.parenthesize_if_needed(
@@ -122,6 +127,29 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper):
             [coeff*expr.base**exp for exp, coeff in expr.data[::-1]]),
             enclosing_prec)
 
+    def map_comparison(self, expr, enclosing_prec):
+        return self.parenthesize_if_needed(
+                self.format("%s %s %s",
+                    self.rec(expr.left, PREC_COMPARISON),
+                    expr.operator,
+                    self.rec(expr.right, PREC_COMPARISON)),
+                enclosing_prec, PREC_COMPARISON)
+
+    def map_logical_not(self, expr, enclosing_prec):
+        return self.parenthesize_if_needed(
+                self.rec(expr.child, PREC_UNARY),
+                enclosing_prec, PREC_UNARY)
+
+    def map_logical_and(self, expr, enclosing_prec):
+        return self.parenthesize_if_needed(
+                self.join_rec(" && ", expr.children, PREC_LOGICAL_AND),
+                enclosing_prec, PREC_LOGICAL_AND)
+
+    def map_logical_or(self, expr, enclosing_prec):
+        return self.parenthesize_if_needed(
+                self.join_rec(" || ", expr.children, PREC_LOGICAL_OR),
+                enclosing_prec, PREC_LOGICAL_OR)
+
     def map_list(self, expr, enclosing_prec):
         return self.format("[%s]", self.join_rec(", ", expr, PREC_NONE))
 
diff --git a/pymbolic/parser.py b/pymbolic/parser.py
index bfc4f55..f40bb06 100644
--- a/pymbolic/parser.py
+++ b/pymbolic/parser.py
@@ -17,45 +17,52 @@ _whitespace = intern("whitespace")
 _comma = intern("comma")
 _dot = intern("dot")
 
-_LEX_TABLE = [
-    (_imaginary, (_float, pytools.lex.RE("j"))),
-    (_float, ("|", 
-               pytools.lex.RE(r"[0-9]+\.[0-9]*([eE]-?[0-9]+)?"),
-               pytools.lex.RE(r"[0-9]+(\.[0-9]*)?[eE]-?[0-9]+"),
-               pytools.lex.RE(r"[0-9]*\.[0-9]+([eE]-?[0-9]+)?"),
-               pytools.lex.RE(r"[0-9]*(\.[0-9]+)?[eE]-?[0-9]+"))),
-    (_int, pytools.lex.RE(r"[0-9]+")),
-    (_plus, pytools.lex.RE(r"\+")),
-    (_minus, pytools.lex.RE(r"-")),
-    (_power, pytools.lex.RE(r"\*\*")),
-    (_times, pytools.lex.RE(r"\*")),
-    (_over, pytools.lex.RE(r"/")),
-    (_openpar, pytools.lex.RE(r"\(")),
-    (_closepar, pytools.lex.RE(r"\)")),
-    (_openbracket, pytools.lex.RE(r"\[")),
-    (_closebracket, pytools.lex.RE(r"\]")),
-    (_identifier, pytools.lex.RE(r"[@a-z_A-Z_][@a-zA-Z_0-9]*")),
-    (_whitespace, pytools.lex.RE("[ \n\t]*")),
-    (_comma, pytools.lex.RE(",")),
-    (_dot, pytools.lex.RE(r"\.")),
-    ]
-
-_PREC_COMMA = 5
-_PREC_PLUS = 10
-_PREC_TIMES = 20
-_PREC_POWER = 30
-_PREC_UNARY_MINUS = 40
-_PREC_CALL = 50
-
-def parse(expr_str):
-    import pymbolic.primitives as primitives
-
-    def parse_terminal(pstate):
+_PREC_COMMA = 5 # must be > 0
+_PREC_LOGICAL_OR = 80
+_PREC_LOGICAL_AND = 90
+_PREC_COMPARISON = 100
+_PREC_PLUS = 110
+_PREC_TIMES = 120
+_PREC_POWER = 130
+_PREC_UNARY = 140
+_PREC_CALL = 150
+
+
+
+
+class Parser:
+    lex_table = [
+            (_imaginary, (_float, pytools.lex.RE("j"))),
+            (_float, ("|",
+                pytools.lex.RE(r"[0-9]+\.[0-9]*([eEdD][+-]?[0-9]+)?"),
+                pytools.lex.RE(r"[0-9]+(\.[0-9]*)?[eEdD][+-]?[0-9]+"),
+                pytools.lex.RE(r"[0-9]*\.[0-9]+([eEdD][+-]?[0-9]+)?"),
+                pytools.lex.RE(r"[0-9]*(\.[0-9]+)?[eEdD][+-]?[0-9]+"))),
+            (_int, pytools.lex.RE(r"[0-9]+")),
+            (_plus, pytools.lex.RE(r"\+")),
+            (_minus, pytools.lex.RE(r"-")),
+            (_power, pytools.lex.RE(r"\*\*")),
+            (_times, pytools.lex.RE(r"\*")),
+            (_over, pytools.lex.RE(r"/")),
+            (_openpar, pytools.lex.RE(r"\(")),
+            (_closepar, pytools.lex.RE(r"\)")),
+            (_openbracket, pytools.lex.RE(r"\[")),
+            (_closebracket, pytools.lex.RE(r"\]")),
+            (_identifier, pytools.lex.RE(r"[@a-z_A-Z_][@a-zA-Z_0-9]*")),
+            (_whitespace, pytools.lex.RE("[ \n\t]*")),
+            (_comma, pytools.lex.RE(",")),
+            (_dot, pytools.lex.RE(r"\.")),
+            ]
+
+    def parse_terminal(self, pstate):
+        import pymbolic.primitives as primitives
+
         next_tag = pstate.next_tag()
         if next_tag is _int:
             return int(pstate.next_str_and_advance())
         elif next_tag is _float:
-            return float(pstate.next_str_and_advance())
+            return float(pstate.next_str_and_advance()
+                    .replace("d", "e").replace("D", "e"))
         elif next_tag is _imaginary:
             return complex(pstate.next_str_and_advance())
         elif next_tag is _identifier:
@@ -63,22 +70,31 @@ def parse(expr_str):
         else:
             pstate.expected("terminal")
 
-    def parse_expression(pstate, min_precedence=0):
+    def parse_prefix(self, pstate):
+        import pymbolic.primitives as primitives
         pstate.expect_not_end()
 
         if pstate.is_next(_times):
             pstate.advance()
             left_exp = primitives.Wildcard()
+        elif pstate.is_next(_plus):
+            pstate.advance()
+            left_exp = self.parse_expression(pstate, _PREC_UNARY)
         elif pstate.is_next(_minus):
             pstate.advance()
-            left_exp = -parse_expression(pstate, _PREC_UNARY_MINUS)
+            left_exp = -self.parse_expression(pstate, _PREC_UNARY)
         elif pstate.is_next(_openpar):
             pstate.advance()
-            left_exp = parse_expression(pstate)
+            left_exp = self.parse_expression(pstate)
             pstate.expect(_closepar)
             pstate.advance()
         else:
-            left_exp = parse_terminal(pstate)
+            left_exp = self.parse_terminal(pstate)
+
+        return left_exp
+
+    def parse_expression(self, pstate, min_precedence=0):
+        left_exp = self.parse_prefix(pstate)
 
         did_something = True
         while did_something:
@@ -86,80 +102,90 @@ def parse(expr_str):
             if pstate.is_at_end():
                 return left_exp
 
-            next_tag = pstate.next_tag()
-
-            if next_tag is _openpar and _PREC_CALL > min_precedence:
-                pstate.advance()
-                pstate.expect_not_end()
-                if pstate.next_tag is _closepar:
-                    pstate.advance()
-                    left_exp = primitives.Call(left_exp, ())
-                else:
-                    args = parse_expression(pstate)
-                    if not isinstance(args, tuple):
-                        args = (args,)
-                    left_exp = primitives.Call(left_exp, args)
-                    pstate.expect(_closepar)
-                    pstate.advance()
-                did_something = True
-            elif next_tag is _openbracket and _PREC_CALL > min_precedence:
-                pstate.advance()
-                pstate.expect_not_end()
-                left_exp = primitives.Subscript(left_exp, parse_expression(pstate))
-                pstate.expect(_closebracket)
-                pstate.advance()
-                did_something = True
-            elif next_tag is _dot and _PREC_CALL > min_precedence:
-                pstate.advance()
-                pstate.expect(_identifier)
-                left_exp = primitives.Lookup(left_exp, pstate.next_str())
-                pstate.advance()
-                did_something = True
-            elif next_tag is _plus and _PREC_PLUS > min_precedence:
-                pstate.advance()
-                left_exp += parse_expression(pstate, _PREC_PLUS)
-                did_something = True
-            elif next_tag is _minus and _PREC_PLUS > min_precedence:
-                pstate.advance()
-                left_exp -= parse_expression(pstate, _PREC_PLUS)
-                did_something = True
-            elif next_tag is _times and _PREC_TIMES > min_precedence:
-                pstate.advance()
-                left_exp *= parse_expression(pstate, _PREC_TIMES)
-                did_something = True
-            elif next_tag is _over and _PREC_TIMES > min_precedence:
-                pstate.advance()
-                left_exp /= parse_expression(pstate, _PREC_TIMES)
-                did_something = True
-            elif next_tag is _power and _PREC_POWER > min_precedence:
-                pstate.advance()
-                left_exp **= parse_expression(pstate, _PREC_POWER)
-                did_something = True
-            elif next_tag is _comma and _PREC_COMMA > min_precedence:
-                # The precedence makes the comma left-associative.
-
-                pstate.advance()
-                if pstate.is_at_end() or pstate.next_tag() is _closepar:
-                    return (left_exp,)
-
-                new_el = parse_expression(pstate, _PREC_COMMA)
-                if isinstance(left_exp, tuple):
-                    left_exp = left_exp + (new_el,)
-                else:
-                    left_exp = (left_exp, new_el)
-                did_something = True
+            left_exp, did_something = self.parse_postfix(
+                    pstate, min_precedence, left_exp)
 
         return left_exp
 
-    pstate = pytools.lex.LexIterator(
-        [(tag, s, idx) 
-         for (tag, s, idx) in pytools.lex.lex(_LEX_TABLE, expr_str)
-         if tag is not _whitespace], expr_str)
+    def parse_postfix(self, pstate, min_precedence, left_exp):
+        import pymbolic.primitives as primitives
 
-    result = parse_expression(pstate)
-    if not pstate.is_at_end():
-        print pstate.next_tag()
-        pstate.raise_parse_error("leftover input after completed parse")
-    return result
+        did_something = False
 
+        next_tag = pstate.next_tag()
 
+        if next_tag is _openpar and _PREC_CALL > min_precedence:
+            pstate.advance()
+            pstate.expect_not_end()
+            if next_tag is _closepar:
+                pstate.advance()
+                left_exp = primitives.Call(left_exp, ())
+            else:
+                args = self.parse_expression(pstate)
+                if not isinstance(args, tuple):
+                    args = (args,)
+                left_exp = primitives.Call(left_exp, args)
+                pstate.expect(_closepar)
+                pstate.advance()
+            did_something = True
+        elif next_tag is _openbracket and _PREC_CALL > min_precedence:
+            pstate.advance()
+            pstate.expect_not_end()
+            left_exp = primitives.Subscript(left_exp, self.parse_expression(pstate))
+            pstate.expect(_closebracket)
+            pstate.advance()
+            did_something = True
+        elif next_tag is _dot and _PREC_CALL > min_precedence:
+            pstate.advance()
+            pstate.expect(_identifier)
+            left_exp = primitives.Lookup(left_exp, pstate.next_str())
+            pstate.advance()
+            did_something = True
+        elif next_tag is _plus and _PREC_PLUS > min_precedence:
+            pstate.advance()
+            left_exp += self.parse_expression(pstate, _PREC_PLUS)
+            did_something = True
+        elif next_tag is _minus and _PREC_PLUS > min_precedence:
+            pstate.advance()
+            left_exp -= self.parse_expression(pstate, _PREC_PLUS)
+            did_something = True
+        elif next_tag is _times and _PREC_TIMES > min_precedence:
+            pstate.advance()
+            left_exp *= self.parse_expression(pstate, _PREC_TIMES)
+            did_something = True
+        elif next_tag is _over and _PREC_TIMES > min_precedence:
+            pstate.advance()
+            left_exp /= self.parse_expression(pstate, _PREC_TIMES)
+            did_something = True
+        elif next_tag is _power and _PREC_POWER > min_precedence:
+            pstate.advance()
+            left_exp **= self.parse_expression(pstate, _PREC_POWER)
+            did_something = True
+        elif next_tag is _comma and _PREC_COMMA > min_precedence:
+            # The precedence makes the comma left-associative.
+
+            pstate.advance()
+            if pstate.is_at_end() or pstate.next_tag() is _closepar:
+                return (left_exp,)
+
+            new_el = self.parse_expression(pstate, _PREC_COMMA)
+            if isinstance(left_exp, tuple):
+                left_exp = left_exp + (new_el,)
+            else:
+                left_exp = (left_exp, new_el)
+            did_something = True
+
+        return left_exp, did_something
+
+    def __call__(self, expr_str):
+        pstate = pytools.lex.LexIterator(
+            [(tag, s, idx)
+             for (tag, s, idx) in pytools.lex.lex(self.lex_table, expr_str)
+             if tag is not _whitespace], expr_str)
+
+        result = self. parse_expression(pstate)
+        if not pstate.is_at_end():
+            pstate.raise_parse_error("leftover input after completed parse")
+        return result
+
+parse = Parser()
diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index 07e9467..b460429 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -505,6 +505,7 @@ class ComparisonOperator(Expression):
         self.right = right
         if not operator in [">", ">=", "==", "<", "<="]:
             raise RuntimeError("invalid operator")
+        self.operator = operator
 
     def __getinitargs__(self):
         return self.left, self.operator, self.right
-- 
GitLab