diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 15ade40d93bba3192590a6cff7dd1f00bc0e4a37..3efbeaafd07f7ccf66ca4ac7de3d936401f4e091 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -139,23 +139,36 @@ def dtype_to_ctype(dtype): # {{{ C argument lists -------------------------------------------------------- class Argument: - def __init__(self, dtype, name): + def __init__(self, dtype, name, vector_len=1): self.dtype = numpy.dtype(dtype) self.name = name + self.vector_len = vector_len def __repr__(self): - return "%s(%r, %s)" % ( + return "%s(%r, %s, %d)" % ( self.__class__.__name__, self.name, - self.dtype) + self.dtype, + self.vector_len) class VectorArg(Argument): def declarator(self): - return "__global %s *%s" % (dtype_to_ctype(self.dtype), self.name) + if self.vector_len == 1: + vlen_str = "" + else: + vlen_str = str(self.vector_len) + + return "__global %s%s *%s" % (dtype_to_ctype(self.dtype), + vlen_str, self.name) class ScalarArg(Argument): def declarator(self): - return "%s %s" % (dtype_to_ctype(self.dtype), self.name) + if self.vector_len == 1: + vlen_str = "" + else: + vlen_str = str(self.vector_len) + + return "%s%s %s" % (dtype_to_ctype(self.dtype), vlen_str, self.name) @@ -171,7 +184,7 @@ def parse_c_arg(c_arg): # process and remove declarator import re - decl_re = re.compile(r"(\**)\s*([_a-zA-Z0-9]+)(\s*\[[ 0-9]*\])*\s*$") + decl_re = re.compile(r"(\**)\s*([_a-zA-Z]+)(\s*\[[ 0-9]*\])*\s*$") decl_match = decl_re.search(c_arg) if decl_match is None: @@ -187,10 +200,20 @@ def parse_c_arg(c_arg): tp = c_arg[:decl_match.start()] tp = " ".join(tp.split()) + type_re = re.compile(r"^([a-z ]+)([0-9]*)$") + type_match = type_re.match(tp) + if not type_match: + raise RuntimeError("type '%s' did not match expected shape of type" + % tp) + + tp = type_match.group(1) + if type_match.group(2): + vector_len = int(type_match.group(2)) + else: + vector_len = 1 + if tp == "float": dtype = numpy.float32 elif tp == "double": dtype = numpy.float64 - elif tp == "pycuda::complex<float>": dtype = numpy.complex64 - elif tp == "pycuda::complex<double>": dtype = numpy.complex128 elif tp in ["int", "signed int"]: dtype = numpy.int32 elif tp in ["unsigned", "unsigned int"]: dtype = numpy.uint32 elif tp in ["long", "long int"]: dtype = numpy.int64 @@ -203,7 +226,7 @@ def parse_c_arg(c_arg): elif tp in ["bool"]: dtype = numpy.bool else: raise ValueError, "unknown type '%s'" % tp - return arg_class(dtype, name) + return arg_class(dtype, name, vector_len) # }}}