From b75f8abdf711309a328a7593d9280b13328bbc52 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 11 May 2015 14:52:00 -0500
Subject: [PATCH] Fortran parsing improvements

---
 loopy/frontend/fortran/__init__.py   |  8 +++++---
 loopy/frontend/fortran/translator.py | 29 ++++++++++++++++------------
 loopy/frontend/fortran/tree.py       | 15 ++++++++++----
 setup.cfg                            |  2 +-
 4 files changed, 34 insertions(+), 20 deletions(-)

diff --git a/loopy/frontend/fortran/__init__.py b/loopy/frontend/fortran/__init__.py
index aec2d4314..1a610b63c 100644
--- a/loopy/frontend/fortran/__init__.py
+++ b/loopy/frontend/fortran/__init__.py
@@ -26,7 +26,8 @@ from loopy.diagnostic import LoopyError
 
 
 def f2loopy(source, free_form=True, strict=True,
-        pre_transform_code=None, use_c_preprocessor=False,
+        pre_transform_code=None, pre_transform_code_context=None,
+        use_c_preprocessor=False,
         file_name="<floopy code>"):
     if use_c_preprocessor:
         try:
@@ -60,9 +61,10 @@ def f2loopy(source, free_form=True, strict=True,
             analyze=False, ignore_comments=False)
 
     from loopy.frontend.fortran.translator import F2LoopyTranslator
-    f2loopy = F2LoopyTranslator()
+    f2loopy = F2LoopyTranslator(file_name)
     f2loopy(tree)
 
-    return f2loopy.make_kernels(pre_transform_code=pre_transform_code)
+    return f2loopy.make_kernels(pre_transform_code=pre_transform_code,
+            pre_transform_code_context=pre_transform_code_context)
 
 # vim: foldmethod=marker
diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py
index a6b5b422b..abd925c14 100644
--- a/loopy/frontend/fortran/translator.py
+++ b/loopy/frontend/fortran/translator.py
@@ -214,7 +214,7 @@ def remove_common_indentation(lines):
 # {{{ translator
 
 class F2LoopyTranslator(FTreeWalkerBase):
-    def __init__(self):
+    def __init__(self, filename):
         FTreeWalkerBase.__init__(self)
 
         self.scope_stack = []
@@ -234,6 +234,8 @@ class F2LoopyTranslator(FTreeWalkerBase):
 
         self.transform_code_lines = []
 
+        self.filename = filename
+
     def add_expression_instruction(self, lhs, rhs):
         scope = self.scope_stack[-1]
 
@@ -331,7 +333,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
 
         tp = self.dtype_from_stmt(node)
 
-        for name, shape in self.parse_dimension_specs(node.entity_decls):
+        for name, shape in self.parse_dimension_specs(node, node.entity_decls):
             if shape is not None:
                 assert name not in scope.dim_map
                 scope.dim_map[name] = shape
@@ -350,7 +352,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
     def map_Dimension(self, node):
         scope = self.scope_stack[-1]
 
-        for name, shape in self.parse_dimension_specs(node.items):
+        for name, shape in self.parse_dimension_specs(node, node.items):
             if shape is not None:
                 assert name not in scope.dim_map
                 scope.dim_map[name] = shape
@@ -369,7 +371,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
         for name, data in node.stmts:
             name, = name
             assert name not in scope.data
-            scope.data[name] = [self.parse_expr(i) for i in data]
+            scope.data[name] = [self.parse_expr(node, i) for i in data]
 
         return []
 
@@ -399,7 +401,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
         scope = self.scope_stack[-1]
 
         lhs = scope.process_expression_for_loopy(
-                self.parse_expr(node.variable))
+                self.parse_expr(node, node.variable))
         from pymbolic.primitives import Subscript, Call
         if isinstance(lhs, Call):
             raise TranslationError("function call (to '%s') on left hand side of"
@@ -411,7 +413,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
 
         scope.use_name(lhs_name)
 
-        rhs = scope.process_expression_for_loopy(self.parse_expr(node.expr))
+        rhs = scope.process_expression_for_loopy(self.parse_expr(node, node.expr))
 
         self.add_expression_instruction(lhs, rhs)
 
@@ -425,9 +427,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
         raise NotImplementedError("save")
 
     def map_Line(self, node):
-        #from warnings import warn
-        #warn("Encountered a 'line': %s" % node)
-        raise NotImplementedError
+        pass
 
     def map_Program(self, node):
         raise NotImplementedError
@@ -467,7 +467,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
         cond_var = var(cond_name)
 
         self.add_expression_instruction(
-                cond_var, self.parse_expr(node.expr))
+                cond_var, self.parse_expr(node, node.expr))
 
         self.conditions.append(cond_name)
 
@@ -489,6 +489,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
             loop_var = loop_var.strip()
             scope.use_name(loop_var)
             loop_bounds = self.parse_expr(
+                    node,
                     loop_bounds, min_precedence=self.expr_parser._PREC_FUNC_ARGS)
 
             if len(loop_bounds) == 2:
@@ -627,12 +628,16 @@ class F2LoopyTranslator(FTreeWalkerBase):
 
     # }}}
 
-    def make_kernels(self, pre_transform_code=None):
+    def make_kernels(self, pre_transform_code=None, pre_transform_code_context=None):
         kernel_names = [
                 sub.subprogram_name
                 for sub in self.kernels]
 
-        proc_dict = {}
+        if pre_transform_code_context is None:
+            proc_dict = {}
+        else:
+            proc_dict = pre_transform_code_context.copy()
+
         proc_dict["lp"] = lp
         proc_dict["np"] = np
 
diff --git a/loopy/frontend/fortran/tree.py b/loopy/frontend/fortran/tree.py
index 4291d9874..b1df6e3d0 100644
--- a/loopy/frontend/fortran/tree.py
+++ b/loopy/frontend/fortran/tree.py
@@ -24,6 +24,8 @@ THE SOFTWARE.
 
 import re
 
+from loopy.diagnostic import LoopyError
+
 
 class FTreeWalkerBase(object):
     def __init__(self):
@@ -53,13 +55,13 @@ class FTreeWalkerBase(object):
             r"^(?P<name>[_0-9a-zA-Z]+)"
             "(\((?P<shape>[-+*0-9:a-zA-Z, \t]+)\))?$")
 
-    def parse_dimension_specs(self, dim_decls):
+    def parse_dimension_specs(self, node, dim_decls):
         def parse_bounds(bounds_str):
             start_end = bounds_str.split(":")
 
             assert 1 <= len(start_end) <= 2
 
-            return [self.parse_expr(s) for s in start_end]
+            return [self.parse_expr(node, s) for s in start_end]
 
         for decl in dim_decls:
             entity_match = self.ENTITY_RE.match(decl)
@@ -81,8 +83,13 @@ class FTreeWalkerBase(object):
 
     # {{{ expressions
 
-    def parse_expr(self, expr_str, **kwargs):
-        return self.expr_parser(expr_str, **kwargs)
+    def parse_expr(self, node, expr_str, **kwargs):
+        try:
+            return self.expr_parser(expr_str, **kwargs)
+        except Exception as e:
+            raise LoopyError(
+                    "Error parsing expression '%s' on line %d of '%s': %s"
+                    % (expr_str, node.item.span[0], self.filename, str(e)))
 
     # }}}
 
diff --git a/setup.cfg b/setup.cfg
index 6faef2e65..d3f13a0e6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,3 @@
 [flake8]
-ignore = E126,E127,E128,E123,E226,E241,E242,E265
+ignore = E126,E127,E128,E123,E226,E241,E242,E265,N802
 max-line-length=85
-- 
GitLab