From 1df2643797bd91f4da95ba6bd8d359e20f9cf228 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sat, 8 Jun 2013 19:05:45 -0400 Subject: [PATCH] Fix argument shape/dtype checks. --- loopy/codegen/__init__.py | 2 ++ loopy/compiled.py | 7 ++++--- loopy/kernel/array.py | 2 ++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index f08d2c23a..cdc917743 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -226,6 +226,7 @@ class CLArgumentInfo(Record): Strides in multiples of ``dtype.itemsize``. .. attribute:: offset_for_name + .. attribute:: allows_offset .. attribute:: arg_class """ @@ -283,6 +284,7 @@ def generate_code(kernel, with_annotation=False, shape=None, strides=None, offset_for_name=None, + allows_offset=None, arg_class=ValueArg)) else: diff --git a/loopy/compiled.py b/loopy/compiled.py index f4b862366..c66f3fb5f 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -162,7 +162,7 @@ class DomainParameterFinder(object): def _arg_matches_spec(arg, val, other_args): import loopy as lp - if isinstance(arg, lp.GlobalArg): + if arg.shape is not None and arg.arg_class is not lp.ImageArg: from pymbolic import evaluate if arg.dtype != val.dtype: @@ -177,13 +177,14 @@ def _arg_matches_spec(arg, val, other_args): "(got: %s, expected: %s)" % (arg.name, val.shape, shape)) - strides = evaluate(arg.numpy_strides, other_args) + itemsize = arg.dtype.itemsize + strides = tuple(itemsize*i for i in evaluate(arg.strides, other_args)) if strides != tuple(val.strides): raise ValueError("strides mismatch on argument '%s' " "(got: %s, expected: %s)" % (arg.name, val.strides, strides)) - if val.offset != 0 and arg.offset == 0: + if val.offset != 0 and not arg.allows_offset: raise ValueError("Argument '%s' does not allow arrays " "with offsets. Try passing default_offset=loopy.auto " "to make_kernel()." % arg.name) diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index e12bf77d5..f90100f00 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -629,6 +629,7 @@ class ArrayBase(Record): shape=shape, strides=strides, offset_for_name=None, + allows_offset=bool(self.offset), arg_class=type(self))) if self.offset: @@ -642,6 +643,7 @@ class ArrayBase(Record): shape=None, strides=None, offset_for_name=full_name, + allows_offset=None, arg_class=None)) return -- GitLab