diff --git a/examples/hello-loopy.py b/examples/hello-loopy.py index a6b8e6d8b8ecdeabbd164d31352958975edf9ec9..a35d7272de26a4993f0d13682f8aba38091ea792 100644 --- a/examples/hello-loopy.py +++ b/examples/hello-loopy.py @@ -15,14 +15,15 @@ a = cl.array.arange(queue, n, dtype=np.float32) # ----------------------------------------------------------------------------- # generation (loopy bits start here) # ----------------------------------------------------------------------------- -knl = lp.make_kernel(ctx.devices[0], - "{[i]: 0<=i row_len = a_rowstarts[i+1] - a_rowstarts[i]", - "ax[i] = sum(jj, a_values[a_rowstarts[i]+jj])", + "a_sum[i] = sum(jj, a_values[[a_rowstarts[i]+jj]])", ], [ - lp.GlobalArg("a_rowstarts", np.int32), - lp.GlobalArg("a_indices", np.int32), + lp.GlobalArg("a_rowstarts", np.int32, shape="auto"), + lp.GlobalArg("a_indices", np.int32, shape="auto"), lp.GlobalArg("a_values", dtype), - lp.GlobalArg("x", dtype), - lp.GlobalArg("ax", dtype), + lp.GlobalArg("a_sum", dtype, shape="auto"), lp.ValueArg("n", np.int32), ], assumptions="n>=1 and row_len>=1") @@ -676,14 +671,13 @@ def test_dependent_loop_bounds_2(ctx_factory): [ "<> row_start = a_rowstarts[i]", "<> row_len = a_rowstarts[i+1] - row_start", - "ax[i] = sum(jj, a_values[row_start+jj])", + "ax[i] = sum(jj, a_values[[row_start+jj]])", ], [ - lp.GlobalArg("a_rowstarts", np.int32), - lp.GlobalArg("a_indices", np.int32), + lp.GlobalArg("a_rowstarts", np.int32, shape="auto"), + lp.GlobalArg("a_indices", np.int32, shape="auto"), lp.GlobalArg("a_values", dtype), - lp.GlobalArg("x", dtype), - lp.GlobalArg("ax", dtype), + lp.GlobalArg("ax", dtype, shape="auto"), lp.ValueArg("n", np.int32), ], assumptions="n>=1 and row_len>=1") @@ -718,7 +712,7 @@ def test_dependent_loop_bounds_3(ctx_factory): "a[i,jj] = 1", ], [ - lp.GlobalArg("a_row_lengths", np.int32), + lp.GlobalArg("a_row_lengths", np.int32, shape="auto"), lp.GlobalArg("a", dtype, shape=("n,n"), order="C"), lp.ValueArg("n", np.int32), ]) @@ -1029,17 +1023,35 @@ def test_write_parameter(ctx_factory): ], assumptions="n>=1") - try: + import pytest + with pytest.raises(RuntimeError): lp.CompiledKernel(ctx, knl).get_code() - except RuntimeError, e: - assert "may not be written" in str(e) - pass # expected! - else: - assert False # expecting an error +def test_arg_shape_guessing(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel(ctx.devices[0], [ + "{[i,j]: 0<=i,j 1: