From a1f14785a7629c4ecf446ca9f105c552fca348a9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 1 Aug 2012 17:04:51 -0400 Subject: [PATCH] Introduce kernel templating helper. --- contrib/pyopencl.vim | 4 +- pyopencl/algorithm.py | 142 +++++++++++++----------------- pyopencl/array.py | 3 + pyopencl/compyte | 2 +- pyopencl/scan.py | 52 ++++++++--- pyopencl/tools.py | 197 +++++++++++++++++++++++++++++++++++++++--- 6 files changed, 291 insertions(+), 109 deletions(-) diff --git a/contrib/pyopencl.vim b/contrib/pyopencl.vim index d3dc8cb4..90d0eeff 100644 --- a/contrib/pyopencl.vim +++ b/contrib/pyopencl.vim @@ -66,11 +66,11 @@ hi link clmakoAttributeValue String " }}} syn region pythonCLString - \ start=+[uU]\=\z('''\|"""\)//CL//+ end="\z1" keepend + \ start=+[uU]\=\z('''\|"""\)//CL\(:[a-zA-Z_0-9]\+\)\?//+ end="\z1" keepend \ contains=@clCode,@clmakoCode syn region pythonCLRawString - \ start=+[uU]\=[rR]\z('''\|"""\)//CL//+ end="\z1" keepend + \ start=+[uU]\=[rR]\z('''\|"""\)//CL\(:[a-zA-Z_0-9]\+\)\?//+ end="\z1" keepend \ contains=@clCode,@clmakoCode " Uncomment if you still want the code highlighted as a string. diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 204139a1..ac8770d1 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -30,7 +30,7 @@ OTHER DEALINGS IN THE SOFTWARE. 
import numpy as np import pyopencl as cl import pyopencl.array -from pyopencl.scan import GenericScanKernel +from pyopencl.scan import GenericScanKernel, ScanTemplate from pyopencl.tools import dtype_to_ctype from pyopencl.tools import context_dependent_memoize from pytools import memoize @@ -41,28 +41,16 @@ from mako.template import Template # {{{ copy_if -@context_dependent_memoize -def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, - extra_args_types, preamble): - ctype = dtype_to_ctype(dtype) - arguments = [ - "__global %s *ary" % ctype, - "__global %s *out" % ctype, - "__global unsigned long *count", - ] + [ - "%s %s" % (dtype_to_ctype(arg_dtype), name) - for name, arg_dtype in extra_args_types] - - return GenericScanKernel( - ctx, dtype, - arguments=", ".join(arguments), - input_expr="(%s) ? 1 : 0" % predicate, - scan_expr="a+b", neutral="0", - output_statement=""" - if (prev_item != item) out[item-1] = ary[i]; - if (i+1 == N) *count = item; - """, - preamble=preamble) +_copy_if_template = ScanTemplate( + arguments="item_t *ary, item_t *out, scan_t *count", + input_expr="(%(predicate)s) ? 1 : 0", + scan_expr="a+b", neutral="0", + output_statement=""" + if (prev_item != item) out[item-1] = ary[i]; + if (i+1 == N) *count = item; + """, + template_processor="printf") + def copy_if(ary, predicate, extra_args=[], queue=None, preamble=""): """Copy the elements of *ary* satisfying *predicate* to an output array. @@ -76,18 +64,20 @@ def copy_if(ary, predicate, extra_args=[], queue=None, preamble=""): is an on-device scalar (fetch to host with `count.get()`) indicating how many elements satisfied *predicate*. 
""" - if len(ary) > np.iinfo(np.uint32).max: - scan_dtype = np.uint64 + if len(ary) > np.iinfo(np.int32).max: + scan_dtype = np.int64 else: - scan_dtype = np.uint32 + scan_dtype = np.int32 - extra_args_types = tuple((name, val.dtype) for name, val in extra_args) + extra_args_types = tuple((val.dtype, name) for name, val in extra_args) extra_args_values = tuple(val for name, val in extra_args) - knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype, - extra_args_types, preamble=preamble) + knl = _copy_if_template.build(ary.context, + type_values=(("scan_t", scan_dtype), ("item_t", ary.dtype)), + var_values=(("predicate", predicate),), + more_preamble=preamble, more_arguments=extra_args_types) out = cl.array.empty_like(ary) - count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) + count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=scan_dtype) knl(ary, out, count, *extra_args_values, queue=queue) return out, count @@ -114,32 +104,22 @@ def remove_if(ary, predicate, extra_args=[], queue=None, preamble=""): # {{{ partition -@context_dependent_memoize -def _get_partition_kernel(ctx, dtype, predicate, scan_dtype, - extra_args_types, preamble): - ctype = dtype_to_ctype(dtype) - arguments = [ - "__global %s *ary" % ctype, - "__global %s *out_true" % ctype, - "__global %s *out_false" % ctype, - "__global unsigned long *count_true", - ] + [ - "%s %s" % (dtype_to_ctype(arg_dtype), name) - for name, arg_dtype in extra_args_types] - - return GenericScanKernel( - ctx, dtype, - arguments=", ".join(arguments), - input_expr="(%s) ? 1 : 0" % predicate, - scan_expr="a+b", neutral="0", - output_statement=""" +_partition_template = ScanTemplate( + arguments=( + "item_t *ary, item_t *out_true, item_t *out_false, " + "scan_t *count_true"), + input_expr="(%(predicate)s) ? 
1 : 0", + scan_expr="a+b", neutral="0", + output_statement="""//CL// if (prev_item != item) out_true[item-1] = ary[i]; else out_false[i-item] = ary[i]; if (i+1 == N) *count_true = item; """, - preamble=preamble) + template_processor="printf") + + def partition(ary, predicate, extra_args=[], queue=None, preamble=""): """Copy the elements of *ary* into one of two arrays depending on whether @@ -158,14 +138,18 @@ def partition(ary, predicate, extra_args=[], queue=None, preamble=""): else: scan_dtype = np.uint32 - extra_args_types = tuple((name, val.dtype) for name, val in extra_args) + extra_args_types = tuple((val.dtype, name) for name, val in extra_args) extra_args_values = tuple(val for name, val in extra_args) - knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype, - extra_args_types, preamble) + knl = _partition_template.build( + ary.context, + type_values=(("item_t", ary.dtype), ("scan_t", scan_dtype)), + var_values=(("predicate", predicate),), + more_preamble=preamble, more_arguments=extra_args_types) + out_true = cl.array.empty_like(ary) out_false = cl.array.empty_like(ary) - count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) + count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=scan_dtype) knl(ary, out_true, out_false, count, *extra_args_values, queue=queue) return out_true, out_false, count @@ -173,35 +157,21 @@ def partition(ary, predicate, extra_args=[], queue=None, preamble=""): # {{{ unique -@context_dependent_memoize -def _get_unique_kernel(ctx, dtype, is_equal_expr, scan_dtype, - extra_args_types, preamble): - ctype = dtype_to_ctype(dtype) - arguments = [ - "__global %s *ary" % ctype, - "__global %s *out" % ctype, - "__global unsigned long *count_unique", - ] + [ - "%s %s" % (dtype_to_ctype(arg_dtype), name) - for name, arg_dtype in extra_args_types] - - from pyopencl.scan import _process_code_for_macro - key_expr_define = "#define IS_EQUAL_EXPR(a, b) %s\n" \ - % 
_process_code_for_macro(is_equal_expr) - return GenericScanKernel( - ctx, dtype, - arguments=", ".join(arguments), - input_fetch_exprs=[ - ("ary_im1", "ary", -1), - ("ary_i", "ary", 0), - ], - input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)", - scan_expr="a+b", neutral="0", - output_statement=""" +_unique_template = ScanTemplate( + arguments="item_t *ary, item_t *out, scan_t *count_unique", + input_fetch_exprs=[ + ("ary_im1", "ary", -1), + ("ary_i", "ary", 0), + ], + input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)", + scan_expr="a+b", neutral="0", + output_statement=""" if (prev_item != item) out[item-1] = ary[i]; if (i+1 == N) *count_unique = item; """, - preamble=preamble+"\n\n"+key_expr_define) + preamble="#define IS_EQUAL_EXPR(a, b) %(macro_is_equal_expr)s\n", + template_processor="printf") + def unique(ary, is_equal_expr="a == b", extra_args=[], queue=None, preamble=""): """Copy the elements of *ary* into the output if *is_equal_expr*, applied to the @@ -225,13 +195,17 @@ def unique(ary, is_equal_expr="a == b", extra_args=[], queue=None, preamble=""): else: scan_dtype = np.uint32 - extra_args_types = tuple((name, val.dtype) for name, val in extra_args) + extra_args_types = tuple((val.dtype, name) for name, val in extra_args) extra_args_values = tuple(val for name, val in extra_args) - knl = _get_unique_kernel(ary.context, ary.dtype, is_equal_expr, scan_dtype, - extra_args_types, preamble) + knl = _unique_template.build( + ary.context, + type_values=(("item_t", ary.dtype), ("scan_t", scan_dtype)), + var_values=(("macro_is_equal_expr", is_equal_expr),), + more_preamble=preamble, more_arguments=extra_args_types) + out = cl.array.empty_like(ary) - count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) + count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=scan_dtype) knl(ary, out, count, *extra_args_values, queue=queue) return out, count diff --git a/pyopencl/array.py b/pyopencl/array.py 
index c06e21b2..d73214d5 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -61,6 +61,7 @@ def _create_vector_types(): from pyopencl.tools import register_dtype + vec.types = {} counts = [2, 3, 4, 8, 16] for base_name, base_type in [ ('char', np.int8), @@ -99,6 +100,8 @@ def _create_vector_types(): % (my_field_names_defaulted, my_field_names), dict(array=np.array, my_dtype=dtype)))) + vec.types[np.dtype(base_type), count] = dtype + _create_vector_types() # }}} diff --git a/pyopencl/compyte b/pyopencl/compyte index 5ab57e78..1e4f772e 160000 --- a/pyopencl/compyte +++ b/pyopencl/compyte @@ -1 +1 @@ -Subproject commit 5ab57e78f05b38ffc9d758b1762ae81f8dab69c3 +Subproject commit 1e4f772ef58883bdd4da06cb07de82feb77ea1bd diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 03bfef17..0ee565f2 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -33,7 +33,8 @@ import numpy as np import pyopencl as cl import pyopencl.array -from pyopencl.tools import dtype_to_ctype, bitlog2 +from pyopencl.tools import (dtype_to_ctype, bitlog2, + KernelTemplateBase, _process_code_for_macro) import pyopencl._mymako as mako from pyopencl._cluda import CLUDA_PREAMBLE @@ -663,15 +664,6 @@ void ${name_prefix}_final_update( # {{{ helpers -def _process_code_for_macro(code): - code = code.replace("//CL//", "\n") - - if "//" in code: - raise RuntimeError("end-of-line comments ('//') may not be used in " - "scan code snippets") - - return code.replace("\n", " \\\n") - def _round_down_to_power_of_2(val): result = 2**bitlog2(val) if result > val: @@ -774,7 +766,7 @@ _IGNORED_WORDS = set(""" use_lookbehind_update store_segment_start_flags update_loop first_seg - a b prev_item i prev_item_unavailable_with_local_update prev_value + a b prev_item i last_item prev_value N NO_SEG_BOUNDARY across_seg_boundary """.split()) @@ -1393,7 +1385,7 @@ class GenericDebugScanKernel(_GenericScanKernelBase): # {{{ compatibility interface -class _LegacyScanKernelBase(GenericScanKernel): # FIXME +class 
# {{{ template

class ScanTemplate(KernelTemplateBase):
    """Reusable description of a scan kernel whose code snippets are
    rendered and compiled on demand.

    The constructor only stores its arguments. :meth:`build_inner` runs
    the stored snippets through a renderer obtained from
    :meth:`get_renderer` and passes the results to *scan_cls*
    (:class:`GenericScanKernel` unless overridden).
    """

    def __init__(self,
            arguments, input_expr, scan_expr, neutral, output_statement,
            is_segment_start_expr=None, input_fetch_exprs=None,
            name_prefix="scan", preamble="", template_processor=None):
        """
        :arg arguments: argument declarations, handed to the renderer's
            ``render_argument_list`` at build time.
        :arg input_expr, scan_expr, neutral, output_statement: code
            snippets, rendered at build time.
        :arg is_segment_start_expr: optional snippet; *None* is passed
            through to the scan kernel unchanged.
        :arg input_fetch_exprs: sequence forwarded (unrendered) to the scan
            kernel. *None* means "no fetch expressions".
        :arg name_prefix: rendered and forwarded as the kernel name prefix.
        :arg preamble: rendered after the ``more_preamble`` supplied at
            build time.
        :arg template_processor: forwarded to :class:`KernelTemplateBase`.
        """

        KernelTemplateBase.__init__(self, template_processor=template_processor)

        # A *None* sentinel plus a copy keeps instances from sharing (and
        # possibly mutating) one default list object.
        if input_fetch_exprs is None:
            input_fetch_exprs = []

        self.arguments = arguments
        self.input_expr = input_expr
        self.scan_expr = scan_expr
        self.neutral = neutral
        self.output_statement = output_statement
        self.is_segment_start_expr = is_segment_start_expr
        self.input_fetch_exprs = list(input_fetch_exprs)
        self.name_prefix = name_prefix
        self.preamble = preamble

    def build_inner(self, context, type_values, var_values,
            more_preamble="", more_arguments=(), declare_types=(),
            options=(), devices=None, scan_cls=GenericScanKernel):
        """Render all snippets and instantiate *scan_cls* with them.

        :arg type_values: pairs mapping type alias names to dtypes. Must
            contain ``"scan_t"`` (looked up unconditionally below);
            ``"index_t"`` is optional and defaults to :class:`numpy.int32`.
        :arg var_values: pairs of template variables made available to the
            snippets.
        :arg more_preamble: extra preamble text, placed before the
            template's own preamble.
        :arg more_arguments: extra arguments appended to the template's own.
        :arg declare_types: accepted for interface compatibility; not used
            by this method.
        :arg options: build options; forwarded to *scan_cls* as a list.
        :arg devices: forwarded to *scan_cls* unchanged.
        :returns: the *scan_cls* instance.
        """
        renderer = self.get_renderer(type_values, var_values, context, options)

        scan_dtype = renderer.type_dict["scan_t"]
        index_dtype = renderer.type_dict.get("index_t", np.int32)
        rendered_args = renderer.render_argument_list(
                self.arguments, more_arguments)

        return scan_cls(context, scan_dtype,
                rendered_args,
                renderer(self.input_expr),
                renderer(self.scan_expr),
                renderer(self.neutral),
                renderer(self.output_statement),
                is_segment_start_expr=renderer(self.is_segment_start_expr),
                # hand out a copy so the kernel cannot alter the template
                input_fetch_exprs=list(self.input_fetch_exprs),
                index_dtype=index_dtype,
                name_prefix=renderer(self.name_prefix),
                options=list(options),
                preamble=renderer(more_preamble+"\n"+self.preamble),
                devices=devices)

# }}}
import numpy as np from decorator import decorator import pyopencl as cl -from pytools import memoize +from pytools import memoize, memoize_method + +import re from pyopencl.compyte.dtypes import ( register_dtype, _fill_dtype_registry, @@ -54,7 +56,7 @@ MemoryPool = cl.MemoryPool -first_arg_dependent_memoized_functions = [] +_first_arg_dependent_caches = [] @@ -79,7 +81,7 @@ def first_arg_dependent_memoize(func, cl_object, *args): try: return ctx_dict[cl_object][args] except KeyError: - first_arg_dependent_memoized_functions.append(func) + _first_arg_dependent_caches.append(ctx_dict) arg_dict = ctx_dict.setdefault(cl_object, {}) result = func(cl_object, *args) arg_dict[args] = result @@ -98,13 +100,8 @@ def clear_first_arg_caches(): .. versionadded:: 2011.2 """ - for func in first_arg_dependent_memoized_functions: - try: - ctx_dict = func._pyopencl_first_arg_dep_memoize_dic - except AttributeError: - pass - else: - ctx_dict.clear() + for cache in _first_arg_dependent_caches: + cache.clear() import atexit atexit.register(clear_first_arg_caches) @@ -494,4 +491,184 @@ def dtype_to_c_struct(device, dtype): +# {{{ code generation/templating helper + +def _process_code_for_macro(code): + code = code.replace("//CL//", "\n") + + if "//" in code: + raise RuntimeError("end-of-line comments ('//') may not be used in " + "code snippets") + + return code.replace("\n", " \\\n") + +class _SimpleTextTemplate: + def __init__(self, txt): + self.txt = txt + + def render(self, context): + return self.txt + +class _PrintfTextTemplate: + def __init__(self, txt): + self.txt = txt + + def render(self, context): + return self.txt % context + +class _MakoTextTemplate: + def __init__(self, txt): + from mako.template import Template + self.template = Template(txt, strict_undefined=True) + + def render(self, context): + return self.template.render(**context) + + + + +class _ArgumentPlaceholder: + def __init__(self, typename, name): + self.typename = typename + self.name = name + + def 
class _ArgumentPlaceholder(object):
    """Kernel argument whose type is resolved only at render time.

    Subclasses set ``target_class``, which :meth:`to_arg` calls as
    ``target_class(dtype, name)``.
    """

    def __init__(self, typename, name):
        self.typename = typename
        self.name = name

    def to_arg(self, type_dict):
        """Resolve the stored type and return a ``target_class`` instance.

        A string type name is looked up in *type_dict* first and, failing
        that, in ``NAME_TO_DTYPE``. Anything else is passed to
        :func:`numpy.dtype`.

        :raises KeyError: if a string type name is found in neither mapping.
        """
        if isinstance(self.typename, str):
            try:
                dtype = type_dict[self.typename]
            except KeyError:
                from pyopencl.compyte.dtypes import NAME_TO_DTYPE
                dtype = NAME_TO_DTYPE[self.typename]
        else:
            dtype = np.dtype(self.typename)

        return self.target_class(dtype, self.name)


class _VectorArgPlaceholder(_ArgumentPlaceholder):
    target_class = VectorArg


class _ScalarArgPlaceholder(_ArgumentPlaceholder):
    target_class = ScalarArg


class _TemplateRenderer(object):
    """Renders code snippets for one combination of type and variable values.

    Calling the instance on a snippet first applies the owning template's
    text processor with the variable values, then replaces each type alias
    name (as a whole word) with the C type name of its dtype.
    """

    def __init__(self, template, type_values, var_values, context=None,
            options=None):
        """
        :arg template: object providing ``get_text_template(txt)``.
        :arg type_values: pairs of ``(alias_name, dtype)``.
        :arg var_values: pairs of ``(variable_name, value)``. Values whose
            name starts with ``macro_`` are run through
            :func:`_process_code_for_macro`.
        :arg context: context used by :meth:`get_rendered_kernel`.
        :arg options: build options used by :meth:`get_rendered_kernel`.
        """
        self.template = template
        self.type_dict = dict(type_values)
        self.var_dict = dict(var_values)

        # Iterate over a snapshot of the keys while values are replaced.
        for name in list(self.var_dict):
            if name.startswith("macro_"):
                self.var_dict[name] = _process_code_for_macro(
                        self.var_dict[name])

        self.context = context

        # Avoid a shared mutable default, and store non-string sequences
        # (such as tuples) as a list.
        if options is None:
            options = []
        elif not isinstance(options, str):
            options = list(options)
        self.options = options

    def __call__(self, txt):
        """Return *txt* rendered, or *None* if *txt* is *None*."""
        if txt is None:
            return txt

        result = self.template.get_text_template(txt).render(self.var_dict)

        # substitute in types
        for name, dtype in self.type_dict.items():
            ctype = dtype_to_ctype(dtype)
            # The alias name is escaped so it is matched literally; the
            # callable keeps the C type name from being read as a regex
            # replacement template.
            result = re.sub(r"\b%s\b" % re.escape(name),
                    lambda match, ctype=ctype: ctype, result)

        return result

    def get_rendered_kernel(self, txt, kernel_name):
        """Render *txt*, build it as a program and return one kernel.

        If the variable values contain ``kernel_name_prefix``, it is
        prepended to *kernel_name* before the attribute lookup on the
        built program.
        """
        prg = cl.Program(self.context, self(txt)).build(self.options)

        kernel_name_prefix = self.var_dict.get("kernel_name_prefix")
        if kernel_name_prefix is not None:
            kernel_name = kernel_name_prefix+kernel_name

        return getattr(prg, kernel_name)

    def render_argument_list(self, arguments, more_arguments):
        """Return a list of argument objects for both argument groups.

        Each group is either a comma-separated string or an iterable of
        entries. An entry may be:

        * a string, parsed with ``parse_c_arg_backend`` into a placeholder
          and resolved against the type values,
        * an :class:`Argument` instance, used as-is,
        * a tuple ``(type, name)``, treated as a scalar argument.

        :raises TypeError: for any other kind of entry.
        """
        all_args = []
        for group in (arguments, more_arguments):
            if isinstance(group, str):
                # Skip blank pieces so that "" or a trailing comma does not
                # produce an empty declaration.
                all_args.extend(
                        piece for piece in group.split(",") if piece.strip())
            else:
                all_args.extend(group)

        from pyopencl.compyte.dtypes import parse_c_arg_backend
        parsed_args = []
        for arg in all_args:
            if isinstance(arg, str):
                # name_to_dtype is the identity so the placeholder receives
                # the raw type name; resolution happens in to_arg().
                ph = parse_c_arg_backend(arg,
                        _ScalarArgPlaceholder, _VectorArgPlaceholder,
                        name_to_dtype=lambda x: x)
                parsed_arg = ph.to_arg(self.type_dict)
            elif isinstance(arg, Argument):
                parsed_arg = arg
            elif isinstance(arg, tuple):
                ph = _ScalarArgPlaceholder(arg[0], arg[1])
                parsed_arg = ph.to_arg(self.type_dict)
            else:
                # Previously this fell through and re-appended the previous
                # entry, or raised NameError on the first one.
                raise TypeError("unsupported argument specification: %r"
                        % (arg,))

            parsed_args.append(parsed_arg)

        return parsed_args


class KernelTemplateBase(object):
    """Base class for kernel templates with per-instance build caching.

    Subclasses implement ``build_inner(context, *args, **kwargs)``;
    :meth:`build` wraps it with a cache.
    """

    def __init__(self, template_processor=None):
        """
        :arg template_processor: default processor for snippets without a
            ``//CL:<name>//`` tag: *None* or ``"none"`` (no processing),
            ``"printf"`` or ``"mako"``.
        """
        self.template_processor = template_processor

        self.build_cache = {}
        # Registered so the cache is cleared together with the other
        # first-argument-dependent caches.
        _first_arg_dependent_caches.append(self.build_cache)

    def get_preamble(self):
        pass

    _TEMPLATE_PROCESSOR_PATTERN = re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")

    @memoize_method
    def get_text_template(self, txt):
        """Return a text-template object for *txt*.

        A leading ``//CL//`` or ``//CL:<name>//`` marker is removed. The
        processor named in the marker takes precedence over the instance's
        ``template_processor``.

        :raises RuntimeError: if the selected processor name is unknown.
        """
        proc_match = self._TEMPLATE_PROCESSOR_PATTERN.match(txt)
        tpl_processor = None

        if proc_match is not None:
            tpl_processor = proc_match.group(1)
            # chop off //CL// mark
            txt = txt[len(proc_match.group(0)):]
        if tpl_processor is None:
            tpl_processor = self.template_processor

        if tpl_processor is None or tpl_processor == "none":
            return _SimpleTextTemplate(txt)
        elif tpl_processor == "printf":
            return _PrintfTextTemplate(txt)
        elif tpl_processor == "mako":
            return _MakoTextTemplate(txt)
        else:
            # Report the name actually selected; proc_match is None when it
            # came from self.template_processor.
            raise RuntimeError("unknown template processor '%s'"
                    % tpl_processor)

    def get_renderer(self, type_values, var_values, context=None, options=None):
        """Return a :class:`_TemplateRenderer` bound to this template.

        *context* and *options* are forwarded so that the renderer's
        ``get_rendered_kernel`` can build programs.
        """
        return _TemplateRenderer(self, type_values, var_values,
                context, options)

    def build(self, context, *args, **kwargs):
        """Provide caching for :meth:`build_inner`.

        The cache key is made of *context*, the positional arguments and
        the sorted keyword arguments, so all of them must be hashable.
        """

        cache_key = (context, args, tuple(sorted(kwargs.items())))
        try:
            return self.build_cache[cache_key]
        except KeyError:
            result = self.build_inner(context, *args, **kwargs)
            self.build_cache[cache_key] = result
            return result

# }}}