diff --git a/loopy/__init__.py b/loopy/__init__.py index 049327d034bc48ff11d40abc92a955088ef6d5b5..4179575847d6a1a102677db7ff4e27bbe84d63fd 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -11,7 +11,6 @@ register_mpz_with_pymbolic() -# TODO: Constant memory (plus check for count) # TODO: Reuse of previously split dimensions for prefetch # (Or general merging) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index e8415bdbba48a71424c0195482c697a5374545c3..4048ca971da518cb6a086c26a5174ab6a8b85b68 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -145,7 +145,7 @@ def generate_code(kernel): Define, Line, Const, LiteralLines, Initializer) from cgen.opencl import (CLKernel, CLGlobal, CLRequiredWorkGroupSize, - CLLocal, CLImage) + CLLocal, CLImage, CLConstant) from loopy.symbolic import LoopyCCodeMapper my_ccm = LoopyCCodeMapper(kernel) @@ -180,7 +180,10 @@ def generate_code(kernel): arg_decl = restrict_ptr_if_not_nvidia( POD(arg.dtype, arg.name)) if arg_decl.name in kernel.input_vectors(): - arg_decl = Const(arg_decl) + if arg.constant_mem: + arg_decl = CLConstant(Const(arg_decl)) + else: + arg_decl = Const(arg_decl) arg_decl = CLGlobal(arg_decl) elif isinstance(arg, ImageArg): if arg.name in kernel.input_vectors(): diff --git a/loopy/kernel.py b/loopy/kernel.py index f3dff96bdc463987f776c7ac1223bd0f2b896b95..b8f9d97537767c7a8f4c7af989cda00c3e9b0c44 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -14,7 +14,7 @@ import pyopencl as cl class ArrayArg: def __init__(self, name, dtype, strides=None, shape=None, order="C", - offset=0): + offset=0, constant_mem=False): """ All of the following are optional. Specify either strides or shape. @@ -49,6 +49,8 @@ class ArrayArg: self.strides = strides self.offset = offset + self.constant_mem = constant_mem + def __repr__(self): return "" % (self.name, self.dtype) @@ -493,6 +495,13 @@ class LoopKernel(Record): msg(4, "using more local memory than available--" "possibly OK due to cache nature") + const_arg_count = sum( + 1 for arg in self.args + if isinstance(arg, ArrayArg) and arg.constant_mem) + + if const_arg_count > self.device.max_constant_args: + msg(5, "too many constant arguments") + max_severity = 0 for sev, msg in msgs: max_severity = max(sev, max_severity)