From a1f14785a7629c4ecf446ca9f105c552fca348a9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 1 Aug 2012 17:04:51 -0400
Subject: [PATCH] Introduce kernel templating helper.

---
 contrib/pyopencl.vim  |   4 +-
 pyopencl/algorithm.py | 142 +++++++++++++-----------------
 pyopencl/array.py     |   3 +
 pyopencl/compyte      |   2 +-
 pyopencl/scan.py      |  52 ++++++++---
 pyopencl/tools.py     | 197 +++++++++++++++++++++++++++++++++++++++---
 6 files changed, 291 insertions(+), 109 deletions(-)

diff --git a/contrib/pyopencl.vim b/contrib/pyopencl.vim
index d3dc8cb4..90d0eeff 100644
--- a/contrib/pyopencl.vim
+++ b/contrib/pyopencl.vim
@@ -66,11 +66,11 @@ hi link clmakoAttributeValue String
 " }}}
 
 syn region pythonCLString
-      \ start=+[uU]\=\z('''\|"""\)//CL//+ end="\z1" keepend
+      \ start=+[uU]\=\z('''\|"""\)//CL\(:[a-zA-Z_0-9]\+\)\?//+ end="\z1" keepend
       \ contains=@clCode,@clmakoCode
 
 syn region pythonCLRawString
-      \ start=+[uU]\=[rR]\z('''\|"""\)//CL//+ end="\z1" keepend
+      \ start=+[uU]\=[rR]\z('''\|"""\)//CL\(:[a-zA-Z_0-9]\+\)\?//+ end="\z1" keepend
       \ contains=@clCode,@clmakoCode
 
 " Uncomment if you still want the code highlighted as a string.
diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 204139a1..ac8770d1 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -30,7 +30,7 @@ OTHER DEALINGS IN THE SOFTWARE.
 import numpy as np
 import pyopencl as cl
 import pyopencl.array
-from pyopencl.scan import GenericScanKernel
+from pyopencl.scan import GenericScanKernel, ScanTemplate
 from pyopencl.tools import dtype_to_ctype
 from pyopencl.tools import context_dependent_memoize
 from pytools import memoize
@@ -41,28 +41,16 @@ from mako.template import Template
 
 # {{{ copy_if
 
-@context_dependent_memoize
-def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype,
-        extra_args_types, preamble):
-    ctype = dtype_to_ctype(dtype)
-    arguments = [
-        "__global %s *ary" % ctype,
-        "__global %s *out" % ctype,
-        "__global unsigned long *count",
-        ] + [
-                "%s %s" % (dtype_to_ctype(arg_dtype), name)
-                for name, arg_dtype in extra_args_types]
-
-    return GenericScanKernel(
-            ctx, dtype,
-            arguments=", ".join(arguments),
-            input_expr="(%s) ? 1 : 0" % predicate,
-            scan_expr="a+b", neutral="0",
-            output_statement="""
-                if (prev_item != item) out[item-1] = ary[i];
-                if (i+1 == N) *count = item;
-                """,
-            preamble=preamble)
+_copy_if_template = ScanTemplate(  # built (and cached) by copy_if() through .build()
+        arguments="item_t *ary, item_t *out, scan_t *count",  # item_t/scan_t are replaced by C types at build time
+        input_expr="(%(predicate)s) ? 1 : 0",  # 1 where the predicate holds, else 0
+        scan_expr="a+b", neutral="0",  # sum scan over those 0/1 flags
+        output_statement="""
+            if (prev_item != item) out[item-1] = ary[i];
+            if (i+1 == N) *count = item;
+            """,
+        template_processor="printf")  # %(name)s fields are filled in from var_values
+
 
 def copy_if(ary, predicate, extra_args=[], queue=None, preamble=""):
     """Copy the elements of *ary* satisfying *predicate* to an output array.
@@ -76,18 +64,20 @@ def copy_if(ary, predicate, extra_args=[], queue=None, preamble=""):
         is an on-device scalar (fetch to host with `count.get()`) indicating
         how many elements satisfied *predicate*.
     """
-    if len(ary) > np.iinfo(np.uint32).max:
-        scan_dtype = np.uint64
+    if len(ary) > np.iinfo(np.int32).max:
+        scan_dtype = np.int64
     else:
-        scan_dtype = np.uint32
+        scan_dtype = np.int32
 
-    extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
+    extra_args_types = tuple((val.dtype, name) for name, val in extra_args)
     extra_args_values = tuple(val for name, val in extra_args)
 
-    knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype,
-            extra_args_types, preamble=preamble)
+    knl = _copy_if_template.build(ary.context,
+            type_values=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
+            var_values=(("predicate", predicate),),
+            more_preamble=preamble, more_arguments=extra_args_types)
     out = cl.array.empty_like(ary)
-    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
+    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=scan_dtype)
     knl(ary, out, count, *extra_args_values, queue=queue)
     return out, count
 
@@ -114,32 +104,22 @@ def remove_if(ary, predicate, extra_args=[], queue=None, preamble=""):
 
 # {{{ partition
 
-@context_dependent_memoize
-def _get_partition_kernel(ctx, dtype, predicate, scan_dtype,
-        extra_args_types, preamble):
-    ctype = dtype_to_ctype(dtype)
-    arguments = [
-        "__global %s *ary" % ctype,
-        "__global %s *out_true" % ctype,
-        "__global %s *out_false" % ctype,
-        "__global unsigned long *count_true",
-        ] + [
-                "%s %s" % (dtype_to_ctype(arg_dtype), name)
-                for name, arg_dtype in extra_args_types]
-
-    return GenericScanKernel(
-            ctx, dtype,
-            arguments=", ".join(arguments),
-            input_expr="(%s) ? 1 : 0" % predicate,
-            scan_expr="a+b", neutral="0",
-            output_statement="""
+_partition_template = ScanTemplate(  # built (and cached) by partition() through .build()
+        arguments=(
+            "item_t *ary, item_t *out_true, item_t *out_false, "
+            "scan_t *count_true"),
+        input_expr="(%(predicate)s) ? 1 : 0",  # 1 where the predicate holds, else 0
+        scan_expr="a+b", neutral="0",  # sum scan over those 0/1 flags
+        output_statement="""//CL//
                 if (prev_item != item)
                     out_true[item-1] = ary[i];
                 else
                     out_false[i-item] = ary[i];
                 if (i+1 == N) *count_true = item;
                 """,
-            preamble=preamble)
+        template_processor="printf")  # %(name)s fields are filled in from var_values
+
+
 
 def partition(ary, predicate, extra_args=[], queue=None, preamble=""):
     """Copy the elements of *ary* into one of two arrays depending on whether
@@ -158,14 +138,18 @@ def partition(ary, predicate, extra_args=[], queue=None, preamble=""):
     else:
         scan_dtype = np.uint32
 
-    extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
+    extra_args_types = tuple((val.dtype, name) for name, val in extra_args)
     extra_args_values = tuple(val for name, val in extra_args)
 
-    knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype,
-            extra_args_types, preamble)
+    knl = _partition_template.build(
+            ary.context,
+            type_values=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
+            var_values=(("predicate", predicate),),
+            more_preamble=preamble, more_arguments=extra_args_types)
+
     out_true = cl.array.empty_like(ary)
     out_false = cl.array.empty_like(ary)
-    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
+    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=scan_dtype)
     knl(ary, out_true, out_false, count, *extra_args_values, queue=queue)
     return out_true, out_false, count
 
@@ -173,35 +157,21 @@ def partition(ary, predicate, extra_args=[], queue=None, preamble=""):
 
 # {{{ unique
 
-@context_dependent_memoize
-def _get_unique_kernel(ctx, dtype, is_equal_expr, scan_dtype,
-        extra_args_types, preamble):
-    ctype = dtype_to_ctype(dtype)
-    arguments = [
-        "__global %s *ary" % ctype,
-        "__global %s *out" % ctype,
-        "__global unsigned long *count_unique",
-        ] + [
-                "%s %s" % (dtype_to_ctype(arg_dtype), name)
-                for name, arg_dtype in extra_args_types]
-
-    from pyopencl.scan import _process_code_for_macro
-    key_expr_define = "#define IS_EQUAL_EXPR(a, b) %s\n" \
-            % _process_code_for_macro(is_equal_expr)
-    return GenericScanKernel(
-            ctx, dtype,
-            arguments=", ".join(arguments),
-            input_fetch_exprs=[
-                ("ary_im1", "ary", -1),
-                ("ary_i", "ary", 0),
-                ],
-            input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)",
-            scan_expr="a+b", neutral="0",
-            output_statement="""
+_unique_template = ScanTemplate(  # built (and cached) by unique() through .build()
+        arguments="item_t *ary, item_t *out, scan_t *count_unique",  # item_t/scan_t are replaced by C types at build time
+        input_fetch_exprs=[  # NOTE(review): presumably (variable name, array, index offset) -- confirm against GenericScanKernel
+            ("ary_im1", "ary", -1),
+            ("ary_i", "ary", 0),
+            ],
+        input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)",  # 1 at i == 0 or when IS_EQUAL_EXPR is false
+        scan_expr="a+b", neutral="0",  # sum scan over those 0/1 flags
+        output_statement="""
                 if (prev_item != item) out[item-1] = ary[i];
                 if (i+1 == N) *count_unique = item;
                 """,
-            preamble=preamble+"\n\n"+key_expr_define)
+        preamble="#define IS_EQUAL_EXPR(a, b) %(macro_is_equal_expr)s\n",  # "macro_" variables get line continuations before substitution
+        template_processor="printf")  # %(name)s fields are filled in from var_values
+
 
 def unique(ary, is_equal_expr="a == b", extra_args=[], queue=None, preamble=""):
     """Copy the elements of *ary* into the output if *is_equal_expr*, applied to the
@@ -225,13 +195,17 @@ def unique(ary, is_equal_expr="a == b", extra_args=[], queue=None, preamble=""):
     else:
         scan_dtype = np.uint32
 
-    extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
+    extra_args_types = tuple((val.dtype, name) for name, val in extra_args)
     extra_args_values = tuple(val for name, val in extra_args)
 
-    knl = _get_unique_kernel(ary.context, ary.dtype, is_equal_expr, scan_dtype,
-            extra_args_types, preamble)
+    knl = _unique_template.build(
+            ary.context,
+            type_values=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
+            var_values=(("macro_is_equal_expr", is_equal_expr),),
+            more_preamble=preamble, more_arguments=extra_args_types)
+
     out = cl.array.empty_like(ary)
-    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
+    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=scan_dtype)
     knl(ary, out, count, *extra_args_values, queue=queue)
     return out, count
 
diff --git a/pyopencl/array.py b/pyopencl/array.py
index c06e21b2..d73214d5 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -61,6 +61,7 @@ def _create_vector_types():
 
     from pyopencl.tools import register_dtype
 
+    vec.types = {}
     counts = [2, 3, 4, 8, 16]
     for base_name, base_type in [
         ('char', np.int8),
@@ -99,6 +100,8 @@ def _create_vector_types():
                         % (my_field_names_defaulted, my_field_names),
                         dict(array=np.array, my_dtype=dtype))))
 
+            vec.types[np.dtype(base_type), count] = dtype
+
 _create_vector_types()
 
 # }}}
diff --git a/pyopencl/compyte b/pyopencl/compyte
index 5ab57e78..1e4f772e 160000
--- a/pyopencl/compyte
+++ b/pyopencl/compyte
@@ -1 +1 @@
-Subproject commit 5ab57e78f05b38ffc9d758b1762ae81f8dab69c3
+Subproject commit 1e4f772ef58883bdd4da06cb07de82feb77ea1bd
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 03bfef17..0ee565f2 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -33,7 +33,8 @@ import numpy as np
 
 import pyopencl as cl
 import pyopencl.array
-from pyopencl.tools import dtype_to_ctype, bitlog2
+from pyopencl.tools import (dtype_to_ctype, bitlog2,
+        KernelTemplateBase, _process_code_for_macro)
 import pyopencl._mymako as mako
 from pyopencl._cluda import CLUDA_PREAMBLE
 
@@ -663,15 +664,6 @@ void ${name_prefix}_final_update(
 
 # {{{ helpers
 
-def _process_code_for_macro(code):
-    code = code.replace("//CL//", "\n")
-
-    if "//" in code:
-        raise RuntimeError("end-of-line comments ('//') may not be used in "
-                "scan code snippets")
-
-    return code.replace("\n", " \\\n")
-
 def _round_down_to_power_of_2(val):
     result = 2**bitlog2(val)
     if result > val:
@@ -774,7 +766,7 @@ _IGNORED_WORDS = set("""
         use_lookbehind_update store_segment_start_flags
         update_loop first_seg
 
-        a b prev_item i prev_item_unavailable_with_local_update prev_value
+        a b prev_item i last_item prev_value
         N NO_SEG_BOUNDARY across_seg_boundary
         """.split())
 
@@ -1393,7 +1385,7 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
 
 # {{{ compatibility interface
 
-class _LegacyScanKernelBase(GenericScanKernel): # FIXME
+class _LegacyScanKernelBase(GenericScanKernel):
     def __init__(self, ctx, dtype,
             scan_expr, neutral=None,
             name_prefix="scan", options=[], preamble="", devices=None):
@@ -1443,4 +1435,40 @@ class ExclusiveScanKernel(_LegacyScanKernelBase):
 
 # }}}
 
+# {{{ template
+
+class ScanTemplate(KernelTemplateBase):  # keeps scan-kernel source pieces until build() supplies types and values
+    def __init__(self,
+            arguments, input_expr, scan_expr, neutral, output_statement,
+            is_segment_start_expr=None, input_fetch_exprs=[],
+            name_prefix="scan", preamble="", template_processor=None):
+
+        KernelTemplateBase.__init__(self, template_processor=template_processor)
+        self.arguments = arguments
+        self.input_expr = input_expr
+        self.scan_expr = scan_expr
+        self.neutral = neutral
+        self.output_statement = output_statement
+        self.is_segment_start_expr = is_segment_start_expr
+        self.input_fetch_exprs = input_fetch_exprs  # passed on to the scan class without rendering
+        self.name_prefix = name_prefix
+        self.preamble = preamble
+
+    def build_inner(self, context, type_values, var_values,
+            more_preamble="", more_arguments=(), declare_types=(),  # declare_types is accepted but unused in this body
+            options=(), devices=None, scan_cls=GenericScanKernel):
+        renderer = self.get_renderer(type_values, var_values, context, options)
+
+        return scan_cls(context, renderer.type_dict["scan_t"],  # type_values must define "scan_t" (KeyError otherwise)
+            renderer.render_argument_list(self.arguments, more_arguments),
+            renderer(self.input_expr), renderer(self.scan_expr), renderer(self.neutral),
+            renderer(self.output_statement),
+            is_segment_start_expr=renderer(self.is_segment_start_expr),
+            input_fetch_exprs=self.input_fetch_exprs,
+            index_dtype=renderer.type_dict.get("index_t", np.int32),  # np.int32 unless type_values defines "index_t"
+            name_prefix=renderer(self.name_prefix), options=list(options),
+            preamble=renderer(more_preamble+"\n"+self.preamble), devices=devices)  # NOTE(review): caller's more_preamble is rendered too, so with "printf" a literal % in it is treated as a format -- confirm intended
+
+# }}}
+
 # vim: filetype=pyopencl:fdm=marker
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index e63ca2f6..e82f8b63 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -32,7 +32,9 @@ OTHER DEALINGS IN THE SOFTWARE.
 import numpy as np
 from decorator import decorator
 import pyopencl as cl
-from pytools import memoize
+from pytools import memoize, memoize_method
+
+import re
 
 from pyopencl.compyte.dtypes import (
         register_dtype, _fill_dtype_registry,
@@ -54,7 +56,7 @@ MemoryPool = cl.MemoryPool
 
 
 
-first_arg_dependent_memoized_functions = []
+_first_arg_dependent_caches = []
 
 
 
@@ -79,7 +81,7 @@ def first_arg_dependent_memoize(func, cl_object, *args):
     try:
         return ctx_dict[cl_object][args]
     except KeyError:
-        first_arg_dependent_memoized_functions.append(func)
+        _first_arg_dependent_caches.append(ctx_dict)
         arg_dict = ctx_dict.setdefault(cl_object, {})
         result = func(cl_object, *args)
         arg_dict[args] = result
@@ -98,13 +100,8 @@ def clear_first_arg_caches():
 
     .. versionadded:: 2011.2
     """
-    for func in first_arg_dependent_memoized_functions:
-        try:
-            ctx_dict = func._pyopencl_first_arg_dep_memoize_dic
-        except AttributeError:
-            pass
-        else:
-            ctx_dict.clear()
+    for cache in _first_arg_dependent_caches:
+        cache.clear()
 
 import atexit
 atexit.register(clear_first_arg_caches)
@@ -494,4 +491,184 @@ def dtype_to_c_struct(device, dtype):
 
 
 
+# {{{ code generation/templating helper
+
+def _process_code_for_macro(code):  # make *code* usable as the body of a multi-line #define
+    code = code.replace("//CL//", "\n")  # drop the //CL// marker before the comment check below
+
+    if "//" in code:  # with the continuations added below, a // comment would extend over the following lines
+        raise RuntimeError("end-of-line comments ('//') may not be used in "
+                "code snippets")
+
+    return code.replace("\n", " \\\n")  # end every line with a backslash continuation
+
+class _SimpleTextTemplate:  # pass-through template: no substitution is performed
+    def __init__(self, txt):
+        self.txt = txt
+
+    def render(self, context):  # context is ignored; the text is returned unchanged
+        return self.txt
+
+class _PrintfTextTemplate:  # template rendered with Python %-formatting
+    def __init__(self, txt):
+        self.txt = txt
+
+    def render(self, context):  # context supplies the %(name)s fields; a literal % must be written as %%
+        return self.txt % context
+
+class _MakoTextTemplate:  # template rendered by Mako
+    def __init__(self, txt):
+        from mako.template import Template  # imported here, so mako is only required when this class is used
+        self.template = Template(txt, strict_undefined=True)  # undeclared names raise instead of rendering as UNDEFINED
+
+    def render(self, context):  # entries of *context* become template variables
+        return self.template.render(**context)
+
+
+
+
+class _ArgumentPlaceholder:  # argument whose type may still be a template type name
+    def __init__(self, typename, name):
+        self.typename = typename
+        self.name = name
+
+    def to_arg(self, type_dict):  # resolve the type and build the concrete argument object
+        if isinstance(self.typename, str):
+            try:
+                dtype = type_dict[self.typename]  # names in type_dict take precedence
+            except KeyError:
+                from pyopencl.compyte.dtypes import NAME_TO_DTYPE
+                dtype = NAME_TO_DTYPE[self.typename]  # fall back to NAME_TO_DTYPE; KeyError if absent there too
+        else:
+            dtype = np.dtype(self.typename)  # non-string: anything np.dtype() accepts
+
+        return self.target_class(dtype, self.name)  # target_class is set by the subclasses
+
+class _VectorArgPlaceholder(_ArgumentPlaceholder):  # to_arg() constructs a VectorArg
+    target_class = VectorArg
+
+class _ScalarArgPlaceholder(_ArgumentPlaceholder):  # to_arg() constructs a ScalarArg
+    target_class = ScalarArg
+
+
+
+
+class _TemplateRenderer(object):  # renders text and argument lists for one set of type/var values
+    def __init__(self, template, type_values, var_values, context=None, options=[]):
+        self.template = template
+        self.type_dict = dict(type_values)
+        self.var_dict = dict(var_values)
+
+        for name in self.var_dict:  # "macro_" values get backslash line continuations
+            if name.startswith("macro_"):
+                self.var_dict[name] = _process_code_for_macro(self.var_dict[name])
+
+        self.context = context
+        self.options = options
+
+    def __call__(self, txt):  # render variables first, then replace type names; None stays None
+        if txt is None:
+            return txt
+
+        result = self.template.get_text_template(txt).render(self.var_dict)
+
+        # substitute in types
+        for name, dtype in self.type_dict.iteritems():
+            result = re.sub(r"\b%s\b" % name, dtype_to_ctype(dtype), result)
+
+        return result
+
+    def get_rendered_kernel(self, txt, kernel_name):  # build rendered *txt*, return the named kernel
+        prg = cl.Program(self.context, self(txt)).build(self.options)
+
+        kernel_name_prefix = self.var_dict.get("kernel_name_prefix")
+        if kernel_name_prefix is not None:
+            kernel_name = kernel_name_prefix+kernel_name
+
+        return getattr(prg, kernel_name)
+
+    def render_argument_list(self, arguments, more_arguments):
+        """Parse both specs (C strings, Argument objects or ``(type, name)``
+        scalar tuples) into one list; raise TypeError for anything else."""
+        all_args = []
+        for arg_group in (arguments, more_arguments):
+            if isinstance(arg_group, str):
+                all_args.extend(arg_group.split(","))
+            else:
+                all_args.extend(arg_group)
+
+        from pyopencl.compyte.dtypes import parse_c_arg_backend
+        parsed_args = []
+        for arg in all_args:
+            if isinstance(arg, str):
+                ph = parse_c_arg_backend(arg,
+                        _ScalarArgPlaceholder, _VectorArgPlaceholder,
+                        name_to_dtype=lambda x: x)
+                parsed_arg = ph.to_arg(self.type_dict)
+            elif isinstance(arg, Argument):
+                parsed_arg = arg
+            elif isinstance(arg, tuple):
+                ph = _ScalarArgPlaceholder(arg[0], arg[1])
+                parsed_arg = ph.to_arg(self.type_dict)
+            else:
+                # do not fall through with a stale or unbound parsed_arg
+                raise TypeError("invalid argument specification: %r" % (arg,))
+
+            parsed_args.append(parsed_arg)
+
+        return parsed_args
+
+
+
+
+class KernelTemplateBase(object):  # base for kernel templates: picks a text processor, caches builds
+    def __init__(self, template_processor=None):
+        self.template_processor = template_processor  # default processor: None/"none", "printf" or "mako"
+
+        self.build_cache = {}
+        _first_arg_dependent_caches.append(self.build_cache)  # registered so clear_first_arg_caches() empties it
+
+    def get_preamble(self):
+        pass
+
+    _TEMPLATE_PROCESSOR_PATTERN = re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")
+
+    @memoize_method
+    def get_text_template(self, txt):  # return the template object that will render *txt*
+        proc_match = self._TEMPLATE_PROCESSOR_PATTERN.match(txt)
+        tpl_processor = None
+
+        # a leading "//CL:<name>//" marker overrides the instance-wide processor
+        if proc_match is not None:
+            tpl_processor = proc_match.group(1)
+            # chop off //CL// mark
+            txt = txt[len(proc_match.group(0)):]
+        if tpl_processor is None:
+            tpl_processor = self.template_processor
+
+        if tpl_processor is None or tpl_processor == "none":
+            return _SimpleTextTemplate(txt)
+        elif tpl_processor == "printf":
+            return _PrintfTextTemplate(txt)
+        elif tpl_processor == "mako":
+            return _MakoTextTemplate(txt)
+        else:  # report tpl_processor itself: proc_match is None when the name came from the instance default
+            raise RuntimeError("unknown template processor '%s'" % tpl_processor)
+
+    def get_renderer(self, type_values, var_values, context=None, options=[]):
+        return _TemplateRenderer(self, type_values, var_values, context, options)
+
+    def build(self, context, *args, **kwargs):
+        """Return the cached result of :meth:`build_inner`, calling it on a miss."""
+
+        cache_key = (context, args, tuple(sorted(kwargs.iteritems())))  # hence all arguments must be hashable
+        try:
+            return self.build_cache[cache_key]
+        except KeyError:
+            result = self.build_inner(context, *args, **kwargs)
+            self.build_cache[cache_key] = result
+            return result
+
+# }}}
+
 # vim: foldmethod=marker
-- 
GitLab