From 53d0aacf76113fdf89a7631360e308349c00f5d2 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 14 May 2015 01:33:01 -0500
Subject: [PATCH] Implement loopy define blocks

---
 examples/fortran/foo.floopy          |  9 +++-
 loopy/frontend/fortran/__init__.py   | 74 ++++++++++++++++++++++++++--
 loopy/frontend/fortran/translator.py | 30 +++--------
 loopy/tools.py                       |  4 +-
 4 files changed, 87 insertions(+), 30 deletions(-)

diff --git a/examples/fortran/foo.floopy b/examples/fortran/foo.floopy
index 0db51ea86..9b0575607 100644
--- a/examples/fortran/foo.floopy
+++ b/examples/fortran/foo.floopy
@@ -1,14 +1,19 @@
+!$loopy begin define
+! define("factor 4.0")
+! define("real_type real*8")
+!$loopy end define
+
 subroutine fill(out, a, n)
   implicit none
 
-  real*8 a, out(n)
+  real_type a, out(n)
   integer n, i
 
   do i = 1, n
     out(i) = a
   end do
   do i = 1, n
-    out(i) = out(i) * 2
+    out(i) = out(i) * factor
   end do
 end
 
diff --git a/loopy/frontend/fortran/__init__.py b/loopy/frontend/fortran/__init__.py
index 1a610b63c..c6b6b3c71 100644
--- a/loopy/frontend/fortran/__init__.py
+++ b/loopy/frontend/fortran/__init__.py
@@ -25,10 +25,56 @@ THE SOFTWARE.
 from loopy.diagnostic import LoopyError
 
 
+def _extract_define_lines(source):
+    lines = source.split("\n")
+
+    import re
+    comment_re = re.compile(r"^\s*\!(.*)$")
+
+    remaining_lines = []
+    define_lines = []
+
+    in_define_code = False
+    for l in lines:
+        comment_match = comment_re.match(l)
+
+        if comment_match is None:
+            if in_define_code:
+                raise LoopyError("non-comment source line in define block")
+
+            remaining_lines.append(l)
+            continue
+
+        cmt = comment_match.group(1)
+        cmt_stripped = cmt.strip()
+
+        if cmt_stripped == "$loopy begin define":
+            if in_define_code:
+                raise LoopyError("can't enter transform code twice")
+            in_define_code = True
+
+        elif cmt_stripped == "$loopy end define":
+            if not in_define_code:
+                raise LoopyError("can't leave transform code twice")
+            in_define_code = False
+
+        elif in_define_code:
+            define_lines.append(cmt)
+
+        else:
+            remaining_lines.append(l)
+
+    return "\n".join(remaining_lines), "\n".join(define_lines)
+
+
 def f2loopy(source, free_form=True, strict=True,
-        pre_transform_code=None, pre_transform_code_context=None,
-        use_c_preprocessor=False,
+        pre_transform_code=None, transform_code_context=None,
+        use_c_preprocessor=False, preprocessor_defines=None,
         file_name="<floopy code>"):
+    """
+    :arg preprocessor_defines: a list of strings as they might occur after a
+        C-style ``#define`` directive, for example ``deg2rad(x) (x/180d0 * 3.14d0)``.
+    """
     if use_c_preprocessor:
         try:
             import ply.lex as lex
@@ -40,6 +86,28 @@ def f2loopy(source, free_form=True, strict=True,
 
         from ply.cpp import Preprocessor
         p = Preprocessor(lexer)
+
+        if preprocessor_defines:
+            for d in preprocessor_defines:
+                p.define(d)
+
+        source, define_code = _extract_define_lines(source)
+        if define_code is not None:
+            from loopy.tools import remove_common_indentation
+            define_code = remove_common_indentation(
+                    define_code,
+                    require_leading_newline=False)
+            def_dict = {}
+            def_dict["define"] = p.define
+
+            if pre_transform_code is not None:
+                def_dict["_MODULE_SOURCE_CODE"] = pre_transform_code
+                exec(compile(pre_transform_code,
+                    "<loopy pre-transform code>", "exec"), def_dict)
+
+            def_dict["_MODULE_SOURCE_CODE"] = define_code
+            exec(compile(define_code, "<loopy defines>", "exec"), def_dict)
+
         p.parse(source, file_name)
 
         tokens = []
@@ -65,6 +133,6 @@ def f2loopy(source, free_form=True, strict=True,
     f2loopy(tree)
 
     return f2loopy.make_kernels(pre_transform_code=pre_transform_code,
-            pre_transform_code_context=pre_transform_code_context)
+            transform_code_context=transform_code_context)
 
 # vim: foldmethod=marker
diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py
index 55e507dbc..7183d9740 100644
--- a/loopy/frontend/fortran/translator.py
+++ b/loopy/frontend/fortran/translator.py
@@ -194,24 +194,6 @@ class Scope(object):
 # }}}
 
 
-def remove_common_indentation(lines):
-    while lines and lines[0].strip() == "":
-        lines.pop(0)
-    while lines and lines[-1].strip() == "":
-        lines.pop(-1)
-
-    if lines:
-        base_indent = 0
-        while lines[0][base_indent] in " \t":
-            base_indent += 1
-
-        for line in lines[1:]:
-            if line[:base_indent].strip():
-                raise ValueError("inconsistent indentation")
-
-    return "\n".join(line[base_indent:] for line in lines)
-
-
 # {{{ translator
 
 class F2LoopyTranslator(FTreeWalkerBase):
@@ -641,15 +623,15 @@ class F2LoopyTranslator(FTreeWalkerBase):
 
     # }}}
 
-    def make_kernels(self, pre_transform_code=None, pre_transform_code_context=None):
+    def make_kernels(self, pre_transform_code=None, transform_code_context=None):
         kernel_names = [
                 sub.subprogram_name
                 for sub in self.kernels]
 
-        if pre_transform_code_context is None:
+        if transform_code_context is None:
             proc_dict = {}
         else:
-            proc_dict = pre_transform_code_context.copy()
+            proc_dict = transform_code_context.copy()
 
         proc_dict["lp"] = lp
         proc_dict["np"] = np
@@ -707,13 +689,15 @@ class F2LoopyTranslator(FTreeWalkerBase):
 
             proc_dict[sub.subprogram_name] = lp.fold_constants(knl)
 
+        from loopy.tools import remove_common_indentation
         transform_code = remove_common_indentation(
-                self.transform_code_lines)
+                "\n".join(self.transform_code_lines),
+                require_leading_newline=False)
 
         if pre_transform_code is not None:
             proc_dict["_MODULE_SOURCE_CODE"] = pre_transform_code
             exec(compile(pre_transform_code,
-                "<loopy transforms>", "exec"), proc_dict)
+                "<loopy pre-transform code>", "exec"), proc_dict)
 
         proc_dict["_MODULE_SOURCE_CODE"] = transform_code
         exec(compile(transform_code,
diff --git a/loopy/tools.py b/loopy/tools.py
index 1f4716067..b3c610dcb 100644
--- a/loopy/tools.py
+++ b/loopy/tools.py
@@ -164,14 +164,14 @@ class PicklableDtype(object):
 
 # {{{ remove common indentation
 
-def remove_common_indentation(code):
+def remove_common_indentation(code, require_leading_newline=True):
     if "\n" not in code:
         return code
 
     # accommodate pyopencl-ish syntax highlighting
     code = code.lstrip("//CL//")
 
-    if not code.startswith("\n"):
+    if require_leading_newline and not code.startswith("\n"):
         return code
 
     lines = code.split("\n")
-- 
GitLab