From 066c259877de2cba3a8499417d0da5f2605e368b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sat, 25 Jun 2011 17:45:27 -0600 Subject: [PATCH] Fix interoperability of Array with vector types. --- pyopencl/__init__.py | 2 +- pyopencl/compyte | 2 +- pyopencl/tools.py | 31 +++++++------------------------ test/test_array.py | 13 +++++++++++++ 4 files changed, 22 insertions(+), 26 deletions(-) diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index db4838dc..4c2f9a22 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -255,7 +255,7 @@ def _add_functionality(): "length of argument list do not agree") for i, (arg, arg_type_char) in enumerate( zip(args, arg_type_chars)): - if arg_type_char: + if arg_type_char and arg_type_char != "V": self.set_arg(i, pack(arg_type_char, arg)) else: self.set_arg(i, arg) diff --git a/pyopencl/compyte b/pyopencl/compyte index 37a5bfef..52aecae2 160000 --- a/pyopencl/compyte +++ b/pyopencl/compyte @@ -1 +1 @@ -Subproject commit 37a5bfef1a46b23bfca1817c6c375506ba09ba28 +Subproject commit 52aecae2c0019caa81342ab79b47f60601a6a1b1 diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 43a60224..2e150246 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -157,36 +157,23 @@ def dtype_to_ctype(dtype): # {{{ C argument lists -------------------------------------------------------- class Argument: - def __init__(self, dtype, name, vector_len=1): + def __init__(self, dtype, name): self.dtype = np.dtype(dtype) self.name = name - self.vector_len = vector_len def __repr__(self): return "%s(%r, %s, %d)" % ( self.__class__.__name__, self.name, - self.dtype, - self.vector_len) + self.dtype) class VectorArg(Argument): def declarator(self): - if self.vector_len == 1: - vlen_str = "" - else: - vlen_str = str(self.vector_len) - - return "__global %s%s *%s" % (dtype_to_ctype(self.dtype), - vlen_str, self.name) + return "__global %s *%s" % (dtype_to_ctype(self.dtype), self.name) class ScalarArg(Argument): def declarator(self): - if self.vector_len == 1: - vlen_str = "" - else: - vlen_str = str(self.vector_len) - - return "%s%s %s" % (dtype_to_ctype(self.dtype), vlen_str, self.name) + return "%s %s" % (dtype_to_ctype(self.dtype), self.name) @@ -218,17 +205,13 @@ def parse_c_arg(c_arg): tp = c_arg[:decl_match.start()] tp = " ".join(tp.split()) - type_re = re.compile(r"^([a-z ]+)([0-9]*)$") + type_re = re.compile(r"^([a-z0-9 ]+)$") type_match = type_re.match(tp) if not type_match: raise RuntimeError("type '%s' did not match expected shape of type" % tp) tp = type_match.group(1) - if type_match.group(2): - vector_len = int(type_match.group(2)) - else: - vector_len = 1 if tp == "float": dtype = np.float32 elif tp == "double": dtype = np.float64 @@ -246,11 +229,11 @@ def parse_c_arg(c_arg): else: import pyopencl.array as cl_array try: - return cl_array.vec._c_name_to_dtype[tp] + dtype = cl_array.vec._c_name_to_dtype[tp] except KeyError: raise ValueError("unknown type '%s'" % tp) - return arg_class(dtype, name, vector_len) + return arg_class(dtype, name) # }}} diff --git a/test/test_array.py b/test/test_array.py index 72c98a89..9627b95f 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -605,6 +605,19 @@ def test_stride_preservation(ctx_getter): +@pytools.test.mark_test.opencl +def test_vector_fill(ctx_getter): + context = ctx_getter() + queue = cl.CommandQueue(context) + + a_gpu = cl_array.Array(queue, 100, dtype=cl_array.vec.float4) + a_gpu.fill(cl_array.vec.make_float4(0.0, 0.0, 1.0, 0.0)) + a = a_gpu.get() + assert a.dtype is cl_array.vec.float4 + + + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl as cl -- GitLab