diff --git a/loopy/compiled.py b/loopy/compiled.py
index c66f3fb5f3ec9040e13b0e315e42ca8bb59e4760..ee87023356074c8bbd21b21cab6d717916480453 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -31,19 +31,34 @@ import numpy as np
 from pytools import Record, memoize_method
 
 
-# {{{ object array argument unpacker
+# {{{ object array argument packing
 
-class ArgumentUnpacker(object):
+class _PackingInfo(Record):
+    """
+    .. attribute:: name
+    .. attribute:: sep_shape
+    .. attribute:: is_written
+    .. attribute:: subscripts_and_names
+
+        A list of type ``[(index, unpacked_name), ...]``.
+    """
+
+
+class SeparateArrayPackingController(object):
     """For argument arrays with axes tagged to be implemented as separate
-    arrays, this class provides preprocessing of the passed arguments so that
+    arrays, this class provides preprocessing of the incoming arguments so that
     all sub-arrays may be passed in one object array (under the original,
     un-split argument name) and are unpacked into separate arrays before
     being passed to the kernel.
+
+    It also repacks outgoing arrays of this type back into an object array.
+
+    .. attribute:: arg_name_to_base_arg_name
     """
 
     def __init__(self, kernel):
-        # a list of items like (arg_name, [(index, unpacked_name), ...])
-        self.unpackable_args = []
+        # map from arg name to _PackingInfo
+        self.packing_info = {}
         self.arg_name_to_base_arg_name = {}
 
         from loopy.kernel.array import ArrayBase, SeparateArrayArrayDimTag
@@ -54,42 +69,65 @@ class ArgumentUnpacker(object):
             if arg.shape is None or arg.dim_tags is None:
                 continue
 
-            log_shape = []
+            sep_shape = []
             for shape_i, dim_tag in zip(arg.shape, arg.dim_tags):
                 if isinstance(dim_tag, SeparateArrayArrayDimTag):
                     if not isinstance(shape_i, int):
                         raise TypeError("argument '%s' has non-integer "
                                 "separate-array axis" % arg.name)
 
-                    log_shape.append(shape_i)
+                    sep_shape.append(shape_i)
 
-            if not log_shape:
+            if not sep_shape:
                 continue
 
             from pytools import indices_in_shape
-            unpack_data = [
+            subscripts_and_names = [
                     (i, arg.name + "".join("_s%d" % sub_i for sub_i in i))
-                    for i in indices_in_shape(log_shape)]
+                    for i in indices_in_shape(sep_shape)]
 
-            self.unpackable_args.append(
-                    (arg.name, unpack_data))
+            self.packing_info[arg.name] = _PackingInfo(
+                    name=arg.name,
+                    sep_shape=sep_shape,
+                    subscripts_and_names=subscripts_and_names,
+                    is_written=arg.name in kernel.get_written_variables())
 
-            for index, sub_arg_name in unpack_data:
+            for index, sub_arg_name in subscripts_and_names:
                 self.arg_name_to_base_arg_name[sub_arg_name] = arg.name
 
-    def __call__(self, kernel_kwargs):
+    def unpack(self, kernel_kwargs):
+        if not self.packing_info:
+            return kernel_kwargs
+
         kernel_kwargs = kernel_kwargs.copy()
 
-        for arg_name, subscripts_and_names in self.unpackable_args:
-            if arg_name in kernel_kwargs:
+        for packing_info in self.packing_info.itervalues():
+            arg_name = packing_info.name
+            if packing_info.name in kernel_kwargs:
                 arg = kernel_kwargs[arg_name]
-                for index, unpacked_name in subscripts_and_names:
+                for index, unpacked_name in packing_info.subscripts_and_names:
                     assert unpacked_name not in kernel_kwargs
                     kernel_kwargs[unpacked_name] = arg[index]
                 del kernel_kwargs[arg_name]
 
         return kernel_kwargs
 
+    def pack(self, outputs):
+        if not self.packing_info:
+            return outputs
+
+        for packing_info in self.packing_info.itervalues():
+            if not packing_info.is_written:
+                continue
+
+            result = outputs[packing_info.name] = \
+                    np.zeros(packing_info.sep_shape, dtype=object)
+
+            for index, unpacked_name in packing_info.subscripts_and_names:
+                result[index] = outputs.pop(unpacked_name)
+
+        return outputs
+
 # }}}
 
 
@@ -247,7 +285,10 @@ class CompiledKernel:
         self.codegen_kwargs = codegen_kwargs
         self.options = options
 
-        self.argument_unpacker = ArgumentUnpacker(kernel)
+        self.packing_controller = SeparateArrayPackingController(kernel)
+
+        self.output_names = tuple(arg.name for arg in self.kernel.args
+                if arg.name in self.kernel.get_written_variables())
 
     @memoize_method
     def get_kernel_info(self, arg_to_dtype_set):
@@ -259,7 +300,7 @@ class CompiledKernel:
         if arg_to_dtype_set:
             arg_to_dtype = {}
             for arg, dtype in arg_to_dtype_set:
-                arg_to_dtype[self.argument_unpacker
+                arg_to_dtype[self.packing_controller
                         .arg_name_to_base_arg_name.get(arg, arg)] = dtype
 
             kernel = add_argument_dtypes(kernel, arg_to_dtype)
@@ -370,6 +411,12 @@ class CompiledKernel:
         If you want offset arguments (see
         :attr:`loopy.kernel.data.GlobalArg.offset`) to be set automatically, it
         must occur *after* the corresponding array argument.
+
+        :arg allocator:
+        :arg wait_for:
+        :arg out_host:
+        :arg warn_numpy:
+        :arg return_dict:
         """
 
         allocator = kwargs.pop("allocator", None)
@@ -378,8 +425,9 @@ class CompiledKernel:
         no_run = kwargs.pop("no_run", None)
         code_op = kwargs.pop("code_op", None)
         warn_numpy = kwargs.pop("warn_numpy", None)
+        return_dict = kwargs.pop("return_dict", False)
 
-        kwargs = self.argument_unpacker(kwargs)
+        kwargs = self.packing_controller.unpack(kwargs)
 
         # {{{ process arg types, get cl kernel
 
@@ -387,7 +435,7 @@ class CompiledKernel:
 
         arg_to_dtype = {}
         for arg_name, val in kwargs.iteritems():
-            arg_name = self.argument_unpacker \
+            arg_name = self.packing_controller \
                     .arg_name_to_base_arg_name.get(arg_name, arg_name)
 
             arg = self.kernel.arg_dict[arg_name]
@@ -415,7 +463,7 @@ class CompiledKernel:
                 for name in kernel.scalar_loop_args)
 
         args = []
-        outputs = []
+        outputs = {}
         encountered_numpy = False
         encountered_cl = False
 
@@ -502,7 +550,7 @@ class CompiledKernel:
                 assert _arg_matches_spec(arg, val, kwargs)
 
             if is_written:
-                outputs.append(val)
+                outputs[arg.name] = val
 
             if arg.arg_class in [lp.GlobalArg, lp.ConstantArg]:
                 args.append(val.base_data)
@@ -524,7 +572,14 @@ class CompiledKernel:
         if out_host is None and (encountered_numpy and not encountered_cl):
             out_host = True
         if out_host:
-            outputs = [o.get(queue=queue) for o in outputs]
+            outputs = dict(
+                    (name, o.get(queue=queue))
+                    for name, o in outputs.iteritems())
+
+        outputs = self.packing_controller.pack(outputs)
+
+        if not return_dict:
+            outputs = tuple(outputs[name] for name in self.output_names)
 
         return evt, outputs
 