From f1012ba21a9b14d361619e7f362c9424e4333be9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 3 May 2015 22:08:22 -0500
Subject: [PATCH] Refactor code generation so that generating pure C function
 bodies is possible

---
 loopy/codegen/__init__.py            | 111 ++++++++++-----------------
 loopy/kernel/data.py                 |  31 +-------
 loopy/target/c/__init__.py           |  55 +++++++++++++
 loopy/target/c/codegen/expression.py |  23 +-----
 loopy/target/opencl/__init__.py      |  37 +++++++++
 loopy/tools.py                       |  36 +++++++++
 6 files changed, 174 insertions(+), 119 deletions(-)

diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 57393bd6d..ecbb55203 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -1,6 +1,4 @@
-from __future__ import division
-from __future__ import absolute_import
-import six
+from __future__ import division, absolute_import
 
 __copyright__ = "Copyright (C) 2012 Andreas Kloeckner"
 
@@ -24,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+import six
 
 from loopy.diagnostic import LoopyError
 from pytools import Record
@@ -146,6 +145,26 @@ def add_comment(cmt, code):
 # }}}
 
 
+class SeenFunction(Record):
+    """
+    .. attribute:: name
+    .. attribute:: c_name
+    .. attribute:: arg_dtypes
+
+        a tuple of arg dtypes
+    """
+
+    def __init__(self, name, c_name, arg_dtypes):
+        Record.__init__(self,
+                name=name,
+                c_name=c_name,
+                arg_dtypes=arg_dtypes)
+
+    def __hash__(self):
+        return hash((type(self),)
+                + tuple((f, getattr(self, f)) for f in type(self).fields))
+
+
 # {{{ code generation state
 
 class CodeGenerationState(object):
@@ -371,30 +390,13 @@ def generate_code(kernel, device=None):
     from loopy.check import pre_codegen_checks
     pre_codegen_checks(kernel)
 
-    from cgen import (FunctionBody, FunctionDeclaration,
-            Value, Module, Block,
-            Line, Const, LiteralLines, Initializer)
-
     logger.info("%s: generate code: start" % kernel.name)
 
-    from cgen.opencl import (CLKernel, CLRequiredWorkGroupSize)
-
-    allow_complex = False
-    for var in kernel.args + list(six.itervalues(kernel.temporary_variables)):
-        if var.dtype.kind == "c":
-            allow_complex = True
-
-    mod = []
-
-    seen_dtypes = set()
-    seen_functions = set()
-
-    body = Block()
-
     # {{{ examine arg list
 
-    from loopy.kernel.data import ImageArg, ValueArg
+    from loopy.kernel.data import ValueArg
     from loopy.kernel.array import ArrayBase
+    from cgen import Const
 
     impl_arg_info = []
 
@@ -417,30 +419,15 @@ def generate_code(kernel, device=None):
         else:
             raise ValueError("argument type not understood: '%s'" % type(arg))
 
-    if any(isinstance(arg, ImageArg) for arg in kernel.args):
-        body.append(Initializer(Const(Value("sampler_t", "loopy_sampler")),
-            "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP "
-                "| CLK_FILTER_NEAREST"))
+    allow_complex = False
+    for var in kernel.args + list(six.itervalues(kernel.temporary_variables)):
+        if var.dtype.kind == "c":
+            allow_complex = True
 
     # }}}
 
-    mod.extend([
-        LiteralLines(r"""
-        #define lid(N) ((%(idx_ctype)s) get_local_id(N))
-        #define gid(N) ((%(idx_ctype)s) get_group_id(N))
-        """ % dict(idx_ctype=kernel.target.dtype_to_typename(kernel.index_dtype))),
-        Line()])
-
-    # {{{ build lmem array declarators for temporary variables
-
-    body.extend(
-            idi.cgen_declarator
-            for tv in six.itervalues(kernel.temporary_variables)
-            for idi in tv.decl_info(
-                kernel.target,
-                is_written=True, index_dtype=kernel.index_dtype))
-
-    # }}}
+    seen_dtypes = set()
+    seen_functions = set()
 
     initial_implemented_domain = isl.BasicSet.from_params(kernel.assumptions)
     codegen_state = CodeGenerationState(
@@ -449,24 +436,12 @@ def generate_code(kernel, device=None):
             expression_to_code_mapper=kernel.target.get_expression_to_code_mapper(
                 kernel, seen_dtypes, seen_functions, allow_complex))
 
-    from loopy.codegen.loop import set_up_hw_parallel_loops
-    gen_code = set_up_hw_parallel_loops(kernel, 0, codegen_state)
-
-    body.append(Line())
-
-    if isinstance(gen_code.ast, Block):
-        body.extend(gen_code.ast.contents)
-    else:
-        body.append(gen_code.ast)
+    code_str, implemented_domains = kernel.target.generate_code(
+            kernel, codegen_state, impl_arg_info)
 
-    mod.append(
-        FunctionBody(
-            CLRequiredWorkGroupSize(
-                kernel.get_grid_sizes_as_exprs()[1],
-                CLKernel(FunctionDeclaration(
-                    Value("void", kernel.name),
-                    [iai.cgen_declarator for iai in impl_arg_info]))),
-            body))
+    from loopy.check import check_implemented_domains
+    assert check_implemented_domains(kernel, implemented_domains,
+            code_str)
 
     # {{{ handle preambles
 
@@ -492,20 +467,18 @@ def generate_code(kernel, device=None):
         seen_preamble_tags.add(tag)
         dedup_preambles.append(preamble)
 
-    mod = ([LiteralLines(lines) for lines in dedup_preambles]
-            + [Line()] + mod)
-
-    # }}}
+    from loopy.tools import remove_common_indentation
+    preamble_codes = [
+            remove_common_indentation(lines) + "\n"
+            for lines in dedup_preambles]
 
-    result = str(Module(mod))
+    code_str = "".join(preamble_codes) + code_str
 
-    from loopy.check import check_implemented_domains
-    assert check_implemented_domains(kernel, gen_code.implemented_domains,
-            result)
+    # }}}
 
     logger.info("%s: generate code: done" % kernel.name)
 
-    result = result, impl_arg_info
+    result = code_str, impl_arg_info
 
     if CACHING_ENABLED:
         code_gen_cache[input_kernel] = result
diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index ab747b9a6..32c062689 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -747,34 +747,6 @@ class ExpressionInstruction(InstructionBase):
 # }}}
 
 
-def _remove_common_indentation(code):
-    if "\n" not in code:
-        return code
-
-    # accommodate pyopencl-ish syntax highlighting
-    code = code.lstrip("//CL//")
-
-    if not code.startswith("\n"):
-        return code
-
-    lines = code.split("\n")
-    while lines[0].strip() == "":
-        lines.pop(0)
-    while lines[-1].strip() == "":
-        lines.pop(-1)
-
-    if lines:
-        base_indent = 0
-        while lines[0][base_indent] in " \t":
-            base_indent += 1
-
-        for line in lines[1:]:
-            if line[:base_indent].strip():
-                raise ValueError("inconsistent indentation")
-
-    return "\n".join(line[base_indent:] for line in lines)
-
-
 # {{{ c instruction
 
 class CInstruction(InstructionBase):
@@ -873,7 +845,8 @@ class CInstruction(InstructionBase):
         # }}}
 
         self.iname_exprs = new_iname_exprs
-        self.code = _remove_common_indentation(code)
+        from loopy.tools import remove_common_indentation
+        self.code = remove_common_indentation(code)
         self.read_variables = read_variables
         self.assignees = new_assignees
 
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index bb1427739..f74485e05 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -24,6 +24,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+import six
+
 import numpy as np  # noqa
 from loopy.target import TargetBase
 
@@ -56,3 +58,56 @@ class CTarget(TargetBase):
         from loopy.target.c.codegen.expression import LoopyCCodeMapper
         return (LoopyCCodeMapper(kernel, seen_dtypes, seen_functions,
             allow_complex=allow_complex))
+
+    # {{{ code generation
+
+    def generate_code(self, kernel, codegen_state, impl_arg_info):
+        from cgen import FunctionBody, FunctionDeclaration, Value, Module
+
+        body, implemented_domains = kernel.target.generate_body(
+                kernel, codegen_state)
+
+        mod = Module([
+            FunctionBody(
+                kernel.target.wrap_function_declaration(
+                    kernel,
+                    FunctionDeclaration(
+                        Value("void", kernel.name),
+                        [iai.cgen_declarator for iai in impl_arg_info])),
+                body)
+            ])
+
+        return str(mod), implemented_domains
+
+    def wrap_function_declaration(self, kernel, fdecl):
+        return fdecl
+
+    def generate_body(self, kernel, codegen_state):
+        from cgen import Block
+        body = Block()
+
+        # {{{ declare temporaries
+
+        body.extend(
+                idi.cgen_declarator
+                for tv in six.itervalues(kernel.temporary_variables)
+                for idi in tv.decl_info(
+                    kernel.target,
+                    is_written=True, index_dtype=kernel.index_dtype))
+
+        # }}}
+
+        from loopy.codegen.loop import set_up_hw_parallel_loops
+        gen_code = set_up_hw_parallel_loops(kernel, 0, codegen_state)
+
+        from cgen import Line
+        body.append(Line())
+
+        if isinstance(gen_code.ast, Block):
+            body.extend(gen_code.ast.contents)
+        else:
+            body.append(gen_code.ast)
+
+        return body, gen_code.implemented_domains
+
+    # }}}
diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py
index ccfa88040..1dc3aafcf 100644
--- a/loopy/target/c/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -32,7 +32,6 @@ from pymbolic.mapper import RecursiveMapper
 from pymbolic.mapper.stringifier import (PREC_NONE, PREC_CALL, PREC_PRODUCT,
         PREC_POWER)
 import islpy as isl
-from pytools import Record
 
 from loopy.expression import dtype_to_type_context, TypeInferenceMapper
 
@@ -47,26 +46,6 @@ def get_opencl_vec_member(idx):
     return "s%s" % hex(int(idx))[2:]
 
 
-class SeenFunction(Record):
-    """
-    .. attribute:: name
-    .. attribute:: c_name
-    .. attribute:: arg_dtypes
-
-        a tuple of arg dtypes
-    """
-
-    def __init__(self, name, c_name, arg_dtypes):
-        Record.__init__(self,
-                name=name,
-                c_name=c_name,
-                arg_dtypes=arg_dtypes)
-
-    def __hash__(self):
-        return hash((type(self),)
-                + tuple((f, getattr(self, f)) for f in type(self).fields))
-
-
 # {{{ C code mapper
 
 class LoopyCCodeMapper(RecursiveMapper):
@@ -330,6 +309,7 @@ class LoopyCCodeMapper(RecursiveMapper):
 
         def seen_func(name):
             idt = self.kernel.index_dtype
+            from loopy.codegen import SeenFunction
             self.seen_functions.add(SeenFunction(name, name, (idt, idt)))
 
         if den_nonneg:
@@ -451,6 +431,7 @@ class LoopyCCodeMapper(RecursiveMapper):
                         "for function '%s' not understood"
                         % identifier)
 
+        from loopy.codegen import SeenFunction
         self.seen_functions.add(SeenFunction(identifier, c_name, par_dtypes))
         if str_parameters is None:
             # /!\ FIXME For some functions (e.g. 'sin'), it makes sense to
diff --git a/loopy/target/opencl/__init__.py b/loopy/target/opencl/__init__.py
index b297eb9ff..e68040c09 100644
--- a/loopy/target/opencl/__init__.py
+++ b/loopy/target/opencl/__init__.py
@@ -237,6 +237,43 @@ class OpenCLTarget(CTarget):
     def get_vector_dtype(self, base, count):
         return vec.types[base, count]
 
+    def wrap_function_declaration(self, kernel, fdecl):
+        from cgen.opencl import CLKernel, CLRequiredWorkGroupSize
+        return CLRequiredWorkGroupSize(
+                kernel.get_grid_sizes_as_exprs()[1],
+                CLKernel(fdecl))
+
+    def generate_code(self, kernel, codegen_state, impl_arg_info):
+        code, implemented_domains = (
+                super(OpenCLTarget, self).generate_code(
+                    kernel, codegen_state, impl_arg_info))
+
+        from loopy.tools import remove_common_indentation
+        code = (
+                remove_common_indentation("""
+                    #define lid(N) ((%(idx_ctype)s) get_local_id(N))
+                    #define gid(N) ((%(idx_ctype)s) get_group_id(N))
+                    """ % dict(idx_ctype=self.dtype_to_typename(kernel.index_dtype)))
+                + "\n\n"
+                + code)
+
+        return code, implemented_domains
+
+    def generate_body(self, kernel, codegen_state):
+        body, implemented_domains = (
+                super(OpenCLTarget, self).generate_body(kernel, codegen_state))
+
+        from loopy.kernel.data import ImageArg
+
+        if any(isinstance(arg, ImageArg) for arg in kernel.args):
+            from cgen import Value, Const, Initializer
+            body.contents.insert(0,
+                    Initializer(Const(Value("sampler_t", "loopy_sampler")),
+                        "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP "
+                        "| CLK_FILTER_NEAREST"))
+
+        return body, implemented_domains
+
 # }}}
 
 # vim: foldmethod=marker
diff --git a/loopy/tools.py b/loopy/tools.py
index 0c8898bb1..1f4716067 100644
--- a/loopy/tools.py
+++ b/loopy/tools.py
@@ -98,6 +98,8 @@ class LoopyKeyBuilder(KeyBuilderBase):
 # }}}
 
 
+# {{{ picklable dtype
+
 class PicklableDtype(object):
     """This object works around several issues with pickling :class:`numpy.dtype`
     objects. It does so by serving as a picklable wrapper around the original
@@ -157,4 +159,38 @@ class PicklableDtype(object):
     def assert_has_target(self):
         assert self.target is not None
 
+# }}}
+
+
+# {{{ remove common indentation
+
+def remove_common_indentation(code):
+    if "\n" not in code:
+        return code
+
+    # accommodate pyopencl-ish syntax highlighting
+    code = code.lstrip("//CL//")
+
+    if not code.startswith("\n"):
+        return code
+
+    lines = code.split("\n")
+    while lines[0].strip() == "":
+        lines.pop(0)
+    while lines[-1].strip() == "":
+        lines.pop(-1)
+
+    if lines:
+        base_indent = 0
+        while lines[0][base_indent] in " \t":
+            base_indent += 1
+
+        for line in lines[1:]:
+            if line[:base_indent].strip():
+                raise ValueError("inconsistent indentation")
+
+    return "\n".join(line[base_indent:] for line in lines)
+
+# }}}
+
 # vim: foldmethod=marker
-- 
GitLab