diff --git a/loopy/compiled.py b/loopy/compiled.py index 4543e96b0c3e1517c1a7044e10143711aaf134cf..f4b862366975f502e06ad59a5f14d520ea964eae 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -44,6 +44,7 @@ class ArgumentUnpacker(object): def __init__(self, kernel): # a list of items like (arg_name, [(index, unpacked_name), ...]) self.unpackable_args = [] + self.arg_name_to_base_arg_name = {} from loopy.kernel.array import ArrayBase, SeparateArrayArrayDimTag for arg in kernel.args: @@ -56,6 +57,10 @@ class ArgumentUnpacker(object): log_shape = [] for shape_i, dim_tag in zip(arg.shape, arg.dim_tags): if isinstance(dim_tag, SeparateArrayArrayDimTag): + if not isinstance(shape_i, int): + raise TypeError("argument '%s' has non-integer " + "separate-array axis" % arg.name) + log_shape.append(shape_i) if not log_shape: @@ -69,6 +74,9 @@ class ArgumentUnpacker(object): self.unpackable_args.append( (arg.name, unpack_data)) + for index, sub_arg_name in unpack_data: + self.arg_name_to_base_arg_name[sub_arg_name] = arg.name + def __call__(self, kernel_kwargs): kernel_kwargs = kernel_kwargs.copy() @@ -248,7 +256,12 @@ class CompiledKernel: from loopy.kernel.tools import add_argument_dtypes if arg_to_dtype_set: - kernel = add_argument_dtypes(kernel, dict(arg_to_dtype_set)) + arg_to_dtype = {} + for arg, dtype in arg_to_dtype_set: + arg_to_dtype[self.argument_unpacker + .arg_name_to_base_arg_name.get(arg, arg)] = dtype + + kernel = add_argument_dtypes(kernel, arg_to_dtype) from loopy.preprocess import infer_unknown_types kernel = infer_unknown_types(kernel, expect_completion=True) @@ -365,13 +378,17 @@ class CompiledKernel: code_op = kwargs.pop("code_op", None) warn_numpy = kwargs.pop("warn_numpy", None) + kwargs = self.argument_unpacker(kwargs) + # {{{ process arg types, get cl kernel import loopy as lp arg_to_dtype = {} - for arg in self.kernel.args: - val = kwargs.get(arg.name) + for arg_name, val in kwargs.iteritems(): + arg_name = self.argument_unpacker \ + .arg_name_to_base_arg_name.get(arg_name, arg_name) + arg = self.kernel.arg_dict[arg_name] if arg.dtype is None and val is not None: try: @@ -390,7 +407,6 @@ class CompiledKernel: # }}} - kwargs = self.argument_unpacker(kwargs) kwargs.update( kernel_info.domain_parameter_finder(kwargs))