From 582139a8bbb0951cc8e6589b6cd35534c7b6df9c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 5 Jun 2013 10:46:53 -0400 Subject: [PATCH] Add argument unpacking for separate-array tagged arguments. --- loopy/compiled.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/loopy/compiled.py b/loopy/compiled.py index c5a44b837..4543e96b0 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -31,10 +31,64 @@ import numpy as np from pytools import Record, memoize_method +# {{{ object array argument unpacker + +class ArgumentUnpacker(object): + """For argument arrays with axes tagged to be implemented as separate + arrays, this class provides preprocessing of the passed arguments so that + all sub-arrays may be passed in one object array (under the original, + un-split argument name) and are unpacked into separate arrays before being + passed to the kernel. + """ + + def __init__(self, kernel): + # a list of items like (arg_name, [(index, unpacked_name), ...]) + self.unpackable_args = [] + + from loopy.kernel.array import ArrayBase, SeparateArrayArrayDimTag + for arg in kernel.args: + if not isinstance(arg, ArrayBase): + continue + + if arg.shape is None or arg.dim_tags is None: + continue + + log_shape = [] + for shape_i, dim_tag in zip(arg.shape, arg.dim_tags): + if isinstance(dim_tag, SeparateArrayArrayDimTag): + log_shape.append(shape_i) + + if not log_shape: + continue + + from pytools import indices_in_shape + unpack_data = [ + (i, arg.name + "".join("_s%d" % sub_i for sub_i in i)) + for i in indices_in_shape(log_shape)] + + self.unpackable_args.append( + (arg.name, unpack_data)) + + def __call__(self, kernel_kwargs): + kernel_kwargs = kernel_kwargs.copy() + + for arg_name, subscripts_and_names in self.unpackable_args: + if arg_name in kernel_kwargs: + arg = kernel_kwargs[arg_name] + for index, unpacked_name in subscripts_and_names: + assert unpacked_name not in kernel_kwargs + kernel_kwargs[unpacked_name] = arg[index] + del kernel_kwargs[arg_name] + + return kernel_kwargs + +# }}} + + # {{{ domain parameter finder class DomainParameterFinder(object): - """Finds parameters from shapes of passed arguments.""" + """Finds domain parameters from shapes of passed arguments.""" def __init__(self, kernel, cl_arg_info): # a mapping from parameter names to a list of tuples @@ -184,6 +238,8 @@ class CompiledKernel: self.codegen_kwargs = codegen_kwargs self.options = options + self.argument_unpacker = ArgumentUnpacker(kernel) + @memoize_method def get_kernel_info(self, arg_to_dtype_set): kernel = self.kernel @@ -334,6 +390,7 @@ class CompiledKernel: # }}} + kwargs = self.argument_unpacker(kwargs) kwargs.update( kernel_info.domain_parameter_finder(kwargs)) -- GitLab