diff --git a/loopy/auto_test.py b/loopy/auto_test.py index e992fa3d48a625c3cb982eed1409be847aadde5f..479b898be610f6c9694be14f2095764ff14b767c 100644 --- a/loopy/auto_test.py +++ b/loopy/auto_test.py @@ -79,7 +79,8 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters): import pyopencl as cl import pyopencl.array as cl_array - from loopy.kernel.data import ValueArg, GlobalArg, ImageArg, TemporaryVariable + from loopy.kernel.data import ValueArg, GlobalArg, ImageArg, \ + TemporaryVariable, ConstantArg from pymbolic import evaluate @@ -107,7 +108,8 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters): ref_arg_data.append(None) - elif arg.arg_class is GlobalArg or arg.arg_class is ImageArg: + elif arg.arg_class is GlobalArg or arg.arg_class is ImageArg \ + or arg.arg_class is ConstantArg: if arg.shape is None or any(saxis is None for saxis in arg.shape): raise LoopyError("array '%s' needs known shape to use automatic " "testing" % arg.name) @@ -196,7 +198,8 @@ def make_args(kernel, impl_arg_info, queue, ref_arg_data, parameters): import pyopencl as cl import pyopencl.array as cl_array - from loopy.kernel.data import ValueArg, GlobalArg, ImageArg, TemporaryVariable + from loopy.kernel.data import ValueArg, GlobalArg, ImageArg,\ + TemporaryVariable, ConstantArg from pymbolic import evaluate @@ -229,7 +232,8 @@ def make_args(kernel, impl_arg_info, queue, ref_arg_data, parameters): args[arg.name] = cl.image_from_array( queue.context, arg_desc.ref_pre_run_array.get()) - elif arg.arg_class is GlobalArg: + elif arg.arg_class is GlobalArg or\ + arg.arg_class is ConstantArg: shape = evaluate(arg.unvec_shape, parameters) strides = evaluate(arg.unvec_strides, parameters) diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index 6d06d8a79d800d9b874537c3c5086ec0af5cb14a..be65d3f89139f900c3be7c3c4204bb0bdb4cbfcf 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -176,7 +176,7 @@ class ExpressionToCExpressionMapper(IdentityMapper): lambda expr: evaluate(expr, self.codegen_state.var_subst_map), self.codegen_state.vectorization_info) - from loopy.kernel.data import ImageArg, GlobalArg, TemporaryVariable + from loopy.kernel.data import ImageArg, GlobalArg, TemporaryVariable, ConstantArg if isinstance(ary, ImageArg): extra_axes = 0 @@ -209,9 +209,10 @@ class ExpressionToCExpressionMapper(IdentityMapper): raise NotImplementedError( "non-floating-point images not supported for now") - elif isinstance(ary, (GlobalArg, TemporaryVariable)): + elif isinstance(ary, (GlobalArg, TemporaryVariable, ConstantArg)): if len(access_info.subscripts) == 0: - if isinstance(ary, GlobalArg): + if isinstance(ary, GlobalArg) or\ + isinstance(ary, ConstantArg): # unsubscripted global args are pointers result = var(access_info.array_name)[0] diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 9c2d428b03c325d3ca8b6af398beb79406bd5170..f0436099c6127e6426b03df2c48342b6ee99c67f 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -486,11 +486,11 @@ class OpenCLCASTBuilder(CASTBuilder): return CLImage(num_target_axes, mode, name) def get_constant_arg_decl(self, name, shape, dtype, is_written): - from loopy.codegen import POD # uses the correct complex type + from loopy.target.c import POD # uses the correct complex type from cgen import RestrictPointer, Const from cgen.opencl import CLConstant - arg_decl = RestrictPointer(POD(dtype, name)) + arg_decl = RestrictPointer(POD(self, dtype, name)) if not is_written: arg_decl = Const(arg_decl) diff --git a/test/test_loopy.py b/test/test_loopy.py index 9d1e2d155c741c5c1a8919ecf846c00931d7c1fb..d56caf0e86e0c4964b6cb8c7164282ff701d953f 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1459,6 +1459,25 @@ def test_unr_and_conditionals(ctx_factory): lp.auto_test_vs_ref(ref_knl, ctx, knl) +def test_constant_array_args(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel('{[k]: 0<=k Tcond[k] = T[k] < 0.5 + if Tcond[k] + cp[k] = 2 * T[k] + Tcond[k] + end + end + """, + [lp.ConstantArg('T', shape=(200,), dtype=np.float32), + '...']) + + knl = lp.fix_parameters(knl, n=200) + + print(lp.generate_code_v2(knl).device_code()) + if __name__ == "__main__": if len(sys.argv) > 1: