diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index f3bdf3c7028bdb8e125a020e2077bcbb4194756f..52318be382d973085ab300313e4e5ce76b79ecdd 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -267,6 +267,43 @@ class ElementwiseKernel:
 
         return cl.enqueue_nd_range_kernel(queue, kernel, gs, ls, wait_for=wait_for)
 
+# }}}
+
+# {{{ template
+
+class ElementwiseTemplate(KernelTemplateBase):
+    """Template for building :class:`ElementwiseKernel` instances."""
+    def __init__(self,
+            arguments, operation, name="elwise", preamble="",
+            template_processor=None):
+
+        KernelTemplateBase.__init__(self, template_processor=template_processor)
+        self.arguments = arguments
+        self.operation = operation
+        self.name = name
+        self.preamble = preamble
+
+    def build_inner(self, context, type_values, var_values,
+            more_preamble="", more_arguments=(), declare_types=(),
+            options=()):
+        """Render this template and return the resulting ElementwiseKernel."""
+        renderer = self.get_renderer(type_values, var_values, context, options)
+
+        arg_list = renderer.render_argument_list(self.arguments, more_arguments)
+        type_decl_preamble = renderer.get_type_decl_preamble(
+                context.devices[0], declare_types, arg_list)
+
+        return ElementwiseKernel(context,
+            arg_list, renderer(self.operation),
+            name=renderer(self.name), options=list(options),
+            preamble=(
+                type_decl_preamble
+                + "\n" + renderer(self.preamble + "\n" + more_preamble)),
+            auto_preamble=True)
+
+# }}}
+
+# {{{ kernels supporting array functionality
 
 @context_dependent_memoize
 def get_take_kernel(context, dtype, idx_dtype, vec_count=1):
@@ -762,3 +799,7 @@ def get_if_positive_kernel(context, crit_dtype, dtype):
             ],
             "result[i] = crit[i] > 0 ? then_[i] : else_[i]",
             name="if_positive")
+
+# }}}
+
+# vim: fdm=marker:filetype=pyopencl
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index e82f8b63e114083b3d6ff84bf7acbbde0a78a365..ab70927cc8d6f532e18fe7ef8a6ecf3f2bc7086c 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -323,6 +323,18 @@ class _CDeclList:
         self.declarations.append(cdecl)
         self.declared_dtypes.add(dtype)
 
+    def visit_arguments(self, arguments):
+        """Record whether any of *arguments* has a double or complex dtype."""
+        for arg in arguments:
+            dtype = arg.dtype
+            # Both double-precision dtypes must be listed individually so
+            # that complex128 sets saw_double as well.
+            if dtype in (np.float64, np.complex128):
+                self.saw_double = True
+
+            if dtype.kind == "c":
+                self.saw_complex = True
+
     def get_declarations(self):
         result = "\n\n".join(self.declarations)
 
@@ -532,18 +544,6 @@ class _ArgumentPlaceholder:
         self.typename = typename
         self.name = name
 
-    def to_arg(self, type_dict):
-        if isinstance(self.typename, str):
-            try:
-                dtype = type_dict[self.typename]
-            except KeyError:
-                from pyopencl.compyte.dtypes import NAME_TO_DTYPE
-                dtype = NAME_TO_DTYPE[self.typename]
-        else:
-            dtype = np.dtype(self.typename)
-
-        return self.target_class(dtype, self.name)
-
 
 class _VectorArgPlaceholder(_ArgumentPlaceholder):
     target_class = VectorArg
@@ -576,7 +576,7 @@ class _TemplateRenderer(object):
         for name, dtype in self.type_dict.iteritems():
             result = re.sub(r"\b%s\b" % name, dtype_to_ctype(dtype), result)
 
-        return result
+        return str(result)
 
     def get_rendered_kernel(self, txt, kernel_name):
         prg = cl.Program(self.context, self(txt)).build(self.options)
@@ -587,37 +587,83 @@ class _TemplateRenderer(object):
 
         return getattr(prg, kernel_name)
 
-    def render_argument_list(self, arguments, more_arguments):
+    def parse_type(self, typename):
+        """Resolve *typename* (a name or dtype-like object) to a dtype."""
+        if isinstance(typename, str):
+            try:
+                return self.type_dict[typename]
+            except KeyError:
+                from pyopencl.compyte.dtypes import NAME_TO_DTYPE
+                return NAME_TO_DTYPE[typename]
+        else:
+            return np.dtype(typename)
+
+    def render_arg(self, arg_placeholder):
+        """Instantiate the placeholder's target class with its parsed type."""
+        return arg_placeholder.target_class(
+                self.parse_type(arg_placeholder.typename),
+                arg_placeholder.name)
+
+    # DOTALL lets a comment that spans several lines be stripped; comments
+    # are removed before newlines are replaced by spaces.
+    _C_COMMENT_FINDER = re.compile(r"/\*.*?\*/", re.DOTALL)
+
+    def render_argument_list(self, *arg_lists):
+        """Parse *arg_lists* into one flat list of argument objects.
+
+        Each entry is either a comma-separated string of declarations or
+        an iterable of strings, Argument instances and (type, name) tuples.
+        """
         all_args = []
 
-        if isinstance(arguments, str):
-            all_args.extend(arguments.split(","))
-        else:
-            all_args.extend(arguments)
+        for arg_list in arg_lists:
+            if isinstance(arg_list, str):
+                if arg_list.startswith("//CL//"):
+                    arg_list = arg_list[6:]
+                arg_list = self._C_COMMENT_FINDER.sub("", arg_list)
+                arg_list = arg_list.replace("\n", " ")
 
-        if isinstance(more_arguments, str):
-            all_args.extend(more_arguments.split(","))
-        else:
-            all_args.extend(more_arguments)
+                all_args.extend(arg_list.split(","))
+            else:
+                all_args.extend(arg_list)
 
         from pyopencl.compyte.dtypes import parse_c_arg_backend
         parsed_args = []
         for arg in all_args:
             if isinstance(arg, str):
+                arg = arg.strip()
+                if not arg:
+                    continue
+
                 ph = parse_c_arg_backend(arg,
                         _ScalarArgPlaceholder, _VectorArgPlaceholder,
                         name_to_dtype=lambda x: x)
-                parsed_arg = ph.to_arg(self.type_dict)
+                parsed_arg = self.render_arg(ph)
+
             elif isinstance(arg, Argument):
                 parsed_arg = arg
             elif isinstance(arg, tuple):
-                ph = _ScalarArgPlaceholder(arg[0], arg[1])
-                parsed_arg = ph.to_arg(self.type_dict)
+                parsed_arg = ScalarArg(self.parse_type(arg[0]), arg[1])
+            else:
+                raise TypeError("unrecognized argument: %r" % (arg,))
 
             parsed_args.append(parsed_arg)
 
         return parsed_args
 
+    def get_type_decl_preamble(self, device, decl_type_names, arguments=None):
+        """Return declarations gathered by _CDeclList for the given types."""
+        cdl = _CDeclList(device)
+
+        for typename in decl_type_names:
+            cdl.add_dtype(self.parse_type(typename))
+
+        if arguments is not None:
+            cdl.visit_arguments(arguments)
+
+        return cdl.get_declarations()
+
+
 
 
 
@@ -638,7 +684,6 @@ class KernelTemplateBase(object):
         proc_match = self._TEMPLATE_PROCESSOR_PATTERN.match(txt)
         tpl_processor = None
 
-        chop_first = 0
         if proc_match is not None:
             tpl_processor = proc_match.group(1)
             # chop off //CL// mark