Skip to content
Snippets Groups Projects
Commit 8064e6c5 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Make separate-array data axis tagging interoperate with run-time typing.

parent 671d3c0a
No related branches found
No related tags found
No related merge requests found
......@@ -44,6 +44,7 @@ class ArgumentUnpacker(object):
def __init__(self, kernel):
# a list of items like (arg_name, [(index, unpacked_name), ...])
self.unpackable_args = []
self.arg_name_to_base_arg_name = {}
from loopy.kernel.array import ArrayBase, SeparateArrayArrayDimTag
for arg in kernel.args:
......@@ -56,6 +57,10 @@ class ArgumentUnpacker(object):
log_shape = []
for shape_i, dim_tag in zip(arg.shape, arg.dim_tags):
if isinstance(dim_tag, SeparateArrayArrayDimTag):
if not isinstance(shape_i, int):
raise TypeError("argument '%s' has non-integer "
"separate-array axis" % arg.name)
log_shape.append(shape_i)
if not log_shape:
......@@ -69,6 +74,9 @@ class ArgumentUnpacker(object):
self.unpackable_args.append(
(arg.name, unpack_data))
for index, sub_arg_name in unpack_data:
self.arg_name_to_base_arg_name[sub_arg_name] = arg.name
def __call__(self, kernel_kwargs):
kernel_kwargs = kernel_kwargs.copy()
......@@ -248,7 +256,12 @@ class CompiledKernel:
from loopy.kernel.tools import add_argument_dtypes
if arg_to_dtype_set:
kernel = add_argument_dtypes(kernel, dict(arg_to_dtype_set))
arg_to_dtype = {}
for arg, dtype in arg_to_dtype_set:
arg_to_dtype[self.argument_unpacker
.arg_name_to_base_arg_name.get(arg, arg)] = dtype
kernel = add_argument_dtypes(kernel, arg_to_dtype)
from loopy.preprocess import infer_unknown_types
kernel = infer_unknown_types(kernel, expect_completion=True)
......@@ -365,13 +378,17 @@ class CompiledKernel:
code_op = kwargs.pop("code_op", None)
warn_numpy = kwargs.pop("warn_numpy", None)
kwargs = self.argument_unpacker(kwargs)
# {{{ process arg types, get cl kernel
import loopy as lp
arg_to_dtype = {}
for arg in self.kernel.args:
val = kwargs.get(arg.name)
for arg_name, val in kwargs.iteritems():
arg_name = self.argument_unpacker \
.arg_name_to_base_arg_name.get(arg_name, arg_name)
arg = self.kernel.arg_dict[arg_name]
if arg.dtype is None and val is not None:
try:
......@@ -390,7 +407,6 @@ class CompiledKernel:
# }}}
kwargs = self.argument_unpacker(kwargs)
kwargs.update(
kernel_info.domain_parameter_finder(kwargs))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment