From b64346eac09633cc5389e594ed6c0ace96e798bc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 17 May 2015 11:31:39 -0500 Subject: [PATCH] Make Fortran-to-CL with more recent pymbolic --- contrib/fortran-to-opencl/translate.py | 50 ++++++++++++++++---------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/contrib/fortran-to-opencl/translate.py b/contrib/fortran-to-opencl/translate.py index d0355423..66b1273c 100644 --- a/contrib/fortran-to-opencl/translate.py +++ b/contrib/fortran-to-opencl/translate.py @@ -27,7 +27,7 @@ import numpy as np import re from pymbolic.parser import Parser as ExpressionParserBase from pymbolic.mapper import CombineMapper -import pymbolic.primitives +import pymbolic.primitives as p from pymbolic.mapper.c_code import CCodeMapper as CCodeMapperBase from warnings import warn @@ -99,7 +99,7 @@ _and = intern("and") _or = intern("or") -class TypedLiteral(pymbolic.primitives.Leaf): +class TypedLiteral(p.Leaf): def __init__(self, value, dtype): self.value = value self.dtype = np.dtype(dtype) @@ -110,6 +110,18 @@ class TypedLiteral(pymbolic.primitives.Leaf): mapper_method = intern("map_literal") +def simplify_typed_literal(expr): + if (isinstance(expr, p.Product) + and len(expr.children) == 2 + and isinstance(expr.children[1], TypedLiteral) + and p.is_constant(expr.children[0]) + and expr.children[0] == -1): + tl = expr.children[1] + return TypedLiteral("-"+tl.value, tl.dtype) + else: + return expr + + class FortranExpressionParser(ExpressionParserBase): # FIXME double/single prec literals @@ -134,7 +146,6 @@ class FortranExpressionParser(ExpressionParserBase): def parse_terminal(self, pstate): scope = self.tree_walker.scope_stack[-1] - from pymbolic.primitives import Subscript, Call, Variable from pymbolic.parser import ( _identifier, _openpar, _closepar, _float) @@ -164,17 +175,17 @@ class FortranExpressionParser(ExpressionParserBase): # not a subscript scope.use_name(name) - return Variable(name) + return p.Variable(name) - left_exp = Variable(name) + left_exp = p.Variable(name) pstate.advance() pstate.expect_not_end() if scope.is_known(name): - cls = Subscript + cls = p.Subscript else: - cls = Call + cls = p.Call if pstate.next_tag is _closepar: pstate.advance() @@ -219,14 +230,14 @@ class FortranExpressionParser(ExpressionParserBase): _PREC_CALL, _PREC_COMPARISON, _openpar, _PREC_LOGICAL_OR, _PREC_LOGICAL_AND) from pymbolic.primitives import ( - ComparisonOperator, LogicalAnd, LogicalOr) + Comparison, LogicalAnd, LogicalOr) next_tag = pstate.next_tag() if next_tag is _openpar and _PREC_CALL > min_precedence: raise TranslationError("parenthesis operator only works on names") elif next_tag in self.COMP_MAP and _PREC_COMPARISON > min_precedence: pstate.advance() - left_exp = ComparisonOperator( + left_exp = Comparison( left_exp, self.COMP_MAP[next_tag], self.parse_expression(pstate, _PREC_COMPARISON)) @@ -250,7 +261,10 @@ class FortranExpressionParser(ExpressionParserBase): assert len(left_exp) == 2 r, i = left_exp - dtype = (r.dtype.type(0) + i.dtype.type(0)) + r = simplify_typed_literal(r) + i = simplify_typed_literal(i) + + dtype = (r.dtype.type(0) + i.dtype.type(0)).dtype if dtype == np.float32: dtype = np.complex64 else: @@ -758,10 +772,9 @@ class ArgumentAnalayzer(FTreeWalkerBase): lhs = self.parse_expr(node.variable) - from pymbolic.primitives import Subscript, Call - if isinstance(lhs, Subscript): + if isinstance(lhs, p.Subscript): lhs_name = lhs.aggregate.name - elif isinstance(lhs, Call): + elif isinstance(lhs, p.Call): # in absence of dim info, subscripts get parsed as calls lhs_name = lhs.function.name else: @@ -797,11 +810,10 @@ class ArgumentAnalayzer(FTreeWalkerBase): def map_Call(self, node): scope = self.scope_stack[-1] - from pymbolic.primitives import Subscript, Variable for i, arg_str in enumerate(node.items): arg = self.parse_expr(arg_str) - if isinstance(arg, (Variable, Subscript)): - if isinstance(arg, Subscript): + if isinstance(arg, (p.Variable, p.Subscript)): + if isinstance(arg, p.Subscript): arg_name = arg.aggregate.name else: arg_name = arg.name @@ -926,9 +938,9 @@ class F2CLTranslator(FTreeWalkerBase): if shape is not None: result.append(cgen.Statement( "%s %s[nitemsof(%s)]" - % ( - dtype_to_ctype(scope.get_type(name)), - name, name))) + % ( + dtype_to_ctype(scope.get_type(name)), + name, name))) else: result.append(self.get_declarator(name)) -- GitLab