From b64346eac09633cc5389e594ed6c0ace96e798bc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 17 May 2015 11:31:39 -0500
Subject: [PATCH] Make Fortran-to-CL with more recent pymbolic

---
 contrib/fortran-to-opencl/translate.py | 50 ++++++++++++++++----------
 1 file changed, 31 insertions(+), 19 deletions(-)

diff --git a/contrib/fortran-to-opencl/translate.py b/contrib/fortran-to-opencl/translate.py
index d0355423..66b1273c 100644
--- a/contrib/fortran-to-opencl/translate.py
+++ b/contrib/fortran-to-opencl/translate.py
@@ -27,7 +27,7 @@ import numpy as np
 import re
 from pymbolic.parser import Parser as ExpressionParserBase
 from pymbolic.mapper import CombineMapper
-import pymbolic.primitives
+import pymbolic.primitives as p
 from pymbolic.mapper.c_code import CCodeMapper as CCodeMapperBase
 
 from warnings import warn
@@ -99,7 +99,7 @@ _and = intern("and")
 _or = intern("or")
 
 
-class TypedLiteral(pymbolic.primitives.Leaf):
+class TypedLiteral(p.Leaf):
     def __init__(self, value, dtype):
         self.value = value
         self.dtype = np.dtype(dtype)
@@ -110,6 +110,18 @@ class TypedLiteral(pymbolic.primitives.Leaf):
     mapper_method = intern("map_literal")
 
 
+def simplify_typed_literal(expr):
+    if (isinstance(expr, p.Product)
+            and len(expr.children) == 2
+            and isinstance(expr.children[1], TypedLiteral)
+            and p.is_constant(expr.children[0])
+            and expr.children[0] == -1):
+        tl = expr.children[1]
+        return TypedLiteral("-"+tl.value, tl.dtype)
+    else:
+        return expr
+
+
 class FortranExpressionParser(ExpressionParserBase):
     # FIXME double/single prec literals
 
@@ -134,7 +146,6 @@ class FortranExpressionParser(ExpressionParserBase):
     def parse_terminal(self, pstate):
         scope = self.tree_walker.scope_stack[-1]
 
-        from pymbolic.primitives import Subscript, Call, Variable
         from pymbolic.parser import (
             _identifier, _openpar, _closepar, _float)
 
@@ -164,17 +175,17 @@ class FortranExpressionParser(ExpressionParserBase):
                 # not a subscript
                 scope.use_name(name)
 
-                return Variable(name)
+                return p.Variable(name)
 
-            left_exp = Variable(name)
+            left_exp = p.Variable(name)
 
             pstate.advance()
             pstate.expect_not_end()
 
             if scope.is_known(name):
-                cls = Subscript
+                cls = p.Subscript
             else:
-                cls = Call
+                cls = p.Call
 
             if pstate.next_tag is _closepar:
                 pstate.advance()
@@ -219,14 +230,14 @@ class FortranExpressionParser(ExpressionParserBase):
                 _PREC_CALL, _PREC_COMPARISON, _openpar,
                 _PREC_LOGICAL_OR, _PREC_LOGICAL_AND)
         from pymbolic.primitives import (
-                ComparisonOperator, LogicalAnd, LogicalOr)
+                Comparison, LogicalAnd, LogicalOr)
 
         next_tag = pstate.next_tag()
         if next_tag is _openpar and _PREC_CALL > min_precedence:
             raise TranslationError("parenthesis operator only works on names")
         elif next_tag in self.COMP_MAP and _PREC_COMPARISON > min_precedence:
             pstate.advance()
-            left_exp = ComparisonOperator(
+            left_exp = Comparison(
                     left_exp,
                     self.COMP_MAP[next_tag],
                     self.parse_expression(pstate, _PREC_COMPARISON))
@@ -250,7 +261,10 @@ class FortranExpressionParser(ExpressionParserBase):
                 assert len(left_exp) == 2
                 r, i = left_exp
 
-                dtype = (r.dtype.type(0) + i.dtype.type(0))
+                r = simplify_typed_literal(r)
+                i = simplify_typed_literal(i)
+
+                dtype = (r.dtype.type(0) + i.dtype.type(0)).dtype
                 if dtype == np.float32:
                     dtype = np.complex64
                 else:
@@ -758,10 +772,9 @@ class ArgumentAnalayzer(FTreeWalkerBase):
 
         lhs = self.parse_expr(node.variable)
 
-        from pymbolic.primitives import Subscript, Call
-        if isinstance(lhs, Subscript):
+        if isinstance(lhs, p.Subscript):
             lhs_name = lhs.aggregate.name
-        elif isinstance(lhs, Call):
+        elif isinstance(lhs, p.Call):
             # in absence of dim info, subscripts get parsed as calls
             lhs_name = lhs.function.name
         else:
@@ -797,11 +810,10 @@ class ArgumentAnalayzer(FTreeWalkerBase):
     def map_Call(self, node):
         scope = self.scope_stack[-1]
 
-        from pymbolic.primitives import Subscript, Variable
         for i, arg_str in enumerate(node.items):
             arg = self.parse_expr(arg_str)
-            if isinstance(arg, (Variable, Subscript)):
-                if isinstance(arg, Subscript):
+            if isinstance(arg, (p.Variable, p.Subscript)):
+                if isinstance(arg, p.Subscript):
                     arg_name = arg.aggregate.name
                 else:
                     arg_name = arg.name
@@ -926,9 +938,9 @@ class F2CLTranslator(FTreeWalkerBase):
                     if shape is not None:
                         result.append(cgen.Statement(
                             "%s %s[nitemsof(%s)]"
-                                % (
-                                    dtype_to_ctype(scope.get_type(name)),
-                                    name, name)))
+                            % (
+                                dtype_to_ctype(scope.get_type(name)),
+                                name, name)))
                     else:
                         result.append(self.get_declarator(name))
 
-- 
GitLab