diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 84b0a4a74bdc0e4dc32b5823c50a4a9c0a8d9f25..3735b2d510aef1ada1643bce9ff07e797bfb210c 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -934,7 +934,8 @@ class ArrayBase(ImmutableRecord): return len(target_axes) def num_user_axes(self, require_answer=True): - if self.shape is not None: + from loopy import auto + if self.shape not in (None, auto): return len(self.shape) if self.dim_tags is not None: return len(self.dim_tags) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 9e6e8db666bab61e85981dc697c2d40cea6a18a6..e6544b34a55af97a1a15e86f7d74855e08e53116 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -525,7 +525,7 @@ class TemporaryVariable(ArrayBase): "_base_storage_access_may_be_aliasing", ] - def __init__(self, name, dtype=None, shape=(), address_space=None, + def __init__(self, name, dtype=None, shape=auto, address_space=None, dim_tags=None, offset=0, dim_names=None, strides=None, order=None, base_indices=None, storage_shape=None, base_storage=None, initializer=None, read_only=False, @@ -579,7 +579,10 @@ class TemporaryVariable(ArrayBase): if shape is auto: shape = initializer.shape - + else: + if shape != initializer.shape: + raise LoopyError("Shape of '{}' does not match that of the" + " initializer.".format(name)) else: raise LoopyError( "temporary variable '%s': " @@ -589,7 +592,7 @@ class TemporaryVariable(ArrayBase): if order is None: order = "C" - if base_indices is None: + if base_indices is None and shape is not auto: base_indices = (0,) * len(shape) if not read_only and initializer is not None: diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 04d436043daed74362ebabd96e18bf1d4d6d4a6c..4569be50367b3063999656bcd1de9d76f98e8c0a 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -551,8 +551,10 @@ class OpenCLCASTBuilder(CFamilyASTBuilder): from loopy.kernel.data import TemporaryVariable, AddressSpace ecm = codegen_state.expression_to_code_mapper.with_assignments( { - old_val_var: TemporaryVariable(old_val_var, lhs_dtype), - new_val_var: TemporaryVariable(new_val_var, lhs_dtype), + old_val_var: TemporaryVariable(old_val_var, lhs_dtype, + shape=()), + new_val_var: TemporaryVariable(new_val_var, lhs_dtype, + shape=()), }) lhs_expr_code = ecm(lhs_expr, prec=PREC_NONE, type_context=None) diff --git a/test/test_loopy.py b/test/test_loopy.py index 61a3f167be66f1c99adc3a52473d8edc747479e1..f9345d5b6cd9b97da80bb2ff8e5c6c657c199402 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -68,8 +68,8 @@ def test_globals_decl_once_with_multi_subprogram(ctx_factory): out[ii] = 2*out[ii]+cnst[ii]{id=second} """, [lp.TemporaryVariable( - 'cnst', shape=('n'), initializer=cnst, - address_space=lp.AddressSpace.GLOBAL, + 'cnst', initializer=cnst, + scope=lp.AddressSpace.GLOBAL, read_only=True), '...']) knl = lp.fix_parameters(knl, n=16) knl = lp.add_barrier(knl, "id:first", "id:second")