diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 8bdc72d54a91c6e8b4f9ec0ca3053831627d3eae..15840180b1c19a4ad9edd31bd87a50ba9f946c48 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -1070,7 +1070,7 @@ def guess_var_shape(kernel, var_name): if n_axes == 1: # Leave shape undetermined--we can live with that for 1D. - shape = (None,) + shape = None else: raise LoopyError("cannot determine access range for '%s': " "undetermined index in subscript(s) '%s'" diff --git a/loopy/target/pyopencl_execution.py b/loopy/target/pyopencl_execution.py index 2da25ba39ceef38a4af105913973226bd3773729..975c691a74d0d17bdca39243f515c5d04284893d 100644 --- a/loopy/target/pyopencl_execution.py +++ b/loopy/target/pyopencl_execution.py @@ -328,7 +328,8 @@ def generate_arg_setup(gen, kernel, implemented_data_info, options): # {{{ allocate written arrays, if needed if is_written and arg.arg_class in [lp.GlobalArg, lp.ConstantArg] \ - and arg.shape is not None: + and arg.shape is not None \ + and all(si is not None for si in arg.shape): if not isinstance(arg.dtype, NumpyType): raise LoopyError("do not know how to pass arg of type '%s'" diff --git a/test/test_loopy.py b/test/test_loopy.py index 563964cf04dfbce5d8983b66010863ef36a74ce7..2c9425cc7d3e3cec21f197bbdb0253a0fa6f89df 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -518,6 +518,32 @@ def test_arg_guessing_with_reduction(ctx_factory): print(knl) print(lp.CompiledKernel(ctx, knl).get_highlighted_code()) + +def test_unknown_arg_shape(ctx_factory): + ctx = ctx_factory() + from loopy.target.pyopencl import PyOpenCLTarget + from loopy.compiled import CompiledKernel + bsize = [256, 0] + + knl = lp.make_kernel( + "{[i,j]: 0<=i gid = i/256 + start = gid*256 + for j + a[start + j] = a[start + j] + j + end + end + """, + seq_dependencies=True, + name="uniform_l", + target=PyOpenCLTarget(), + assumptions="m<=%d and m>=1 and n mod %d = 0" % (bsize[0], bsize[0])) + + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32)) + cl_kernel_info = CompiledKernel(ctx, knl).cl_kernel_info(frozenset()) # noqa + # }}}