Skip to content
Snippets Groups Projects
Commit 066c2598 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix interoperability of Array with vector types.

parent b45b94c8
No related branches found
No related tags found
No related merge requests found
...@@ -255,7 +255,7 @@ def _add_functionality(): ...@@ -255,7 +255,7 @@ def _add_functionality():
"length of argument list do not agree") "length of argument list do not agree")
for i, (arg, arg_type_char) in enumerate( for i, (arg, arg_type_char) in enumerate(
zip(args, arg_type_chars)): zip(args, arg_type_chars)):
if arg_type_char: if arg_type_char and arg_type_char != "V":
self.set_arg(i, pack(arg_type_char, arg)) self.set_arg(i, pack(arg_type_char, arg))
else: else:
self.set_arg(i, arg) self.set_arg(i, arg)
......
compyte @ 52aecae2
Subproject commit 37a5bfef1a46b23bfca1817c6c375506ba09ba28 Subproject commit 52aecae2c0019caa81342ab79b47f60601a6a1b1
...@@ -157,36 +157,23 @@ def dtype_to_ctype(dtype): ...@@ -157,36 +157,23 @@ def dtype_to_ctype(dtype):
# {{{ C argument lists -------------------------------------------------------- # {{{ C argument lists --------------------------------------------------------
class Argument: class Argument:
def __init__(self, dtype, name, vector_len=1): def __init__(self, dtype, name):
self.dtype = np.dtype(dtype) self.dtype = np.dtype(dtype)
self.name = name self.name = name
self.vector_len = vector_len
def __repr__(self): def __repr__(self):
return "%s(%r, %s, %d)" % ( return "%s(%r, %s, %d)" % (
self.__class__.__name__, self.__class__.__name__,
self.name, self.name,
self.dtype, self.dtype)
self.vector_len)
class VectorArg(Argument): class VectorArg(Argument):
def declarator(self): def declarator(self):
if self.vector_len == 1: return "__global %s *%s" % (dtype_to_ctype(self.dtype), self.name)
vlen_str = ""
else:
vlen_str = str(self.vector_len)
return "__global %s%s *%s" % (dtype_to_ctype(self.dtype),
vlen_str, self.name)
class ScalarArg(Argument): class ScalarArg(Argument):
def declarator(self): def declarator(self):
if self.vector_len == 1: return "%s %s" % (dtype_to_ctype(self.dtype), self.name)
vlen_str = ""
else:
vlen_str = str(self.vector_len)
return "%s%s %s" % (dtype_to_ctype(self.dtype), vlen_str, self.name)
...@@ -218,17 +205,13 @@ def parse_c_arg(c_arg): ...@@ -218,17 +205,13 @@ def parse_c_arg(c_arg):
tp = c_arg[:decl_match.start()] tp = c_arg[:decl_match.start()]
tp = " ".join(tp.split()) tp = " ".join(tp.split())
type_re = re.compile(r"^([a-z ]+)([0-9]*)$") type_re = re.compile(r"^([a-z0-9 ]+)$")
type_match = type_re.match(tp) type_match = type_re.match(tp)
if not type_match: if not type_match:
raise RuntimeError("type '%s' did not match expected shape of type" raise RuntimeError("type '%s' did not match expected shape of type"
% tp) % tp)
tp = type_match.group(1) tp = type_match.group(1)
if type_match.group(2):
vector_len = int(type_match.group(2))
else:
vector_len = 1
if tp == "float": dtype = np.float32 if tp == "float": dtype = np.float32
elif tp == "double": dtype = np.float64 elif tp == "double": dtype = np.float64
...@@ -246,11 +229,11 @@ def parse_c_arg(c_arg): ...@@ -246,11 +229,11 @@ def parse_c_arg(c_arg):
else: else:
import pyopencl.array as cl_array import pyopencl.array as cl_array
try: try:
return cl_array.vec._c_name_to_dtype[tp] dtype = cl_array.vec._c_name_to_dtype[tp]
except KeyError: except KeyError:
raise ValueError("unknown type '%s'" % tp) raise ValueError("unknown type '%s'" % tp)
return arg_class(dtype, name, vector_len) return arg_class(dtype, name)
# }}} # }}}
......
...@@ -605,6 +605,19 @@ def test_stride_preservation(ctx_getter): ...@@ -605,6 +605,19 @@ def test_stride_preservation(ctx_getter):
@pytools.test.mark_test.opencl
def test_vector_fill(ctx_getter):
context = ctx_getter()
queue = cl.CommandQueue(context)
a_gpu = cl_array.Array(queue, 100, dtype=cl_array.vec.float4)
a_gpu.fill(cl_array.vec.make_float4(0.0, 0.0, 1.0, 0.0))
a = a_gpu.get()
assert a.dtype is cl_array.vec.float4
if __name__ == "__main__": if __name__ == "__main__":
# make sure that import failures get reported, instead of skipping the tests. # make sure that import failures get reported, instead of skipping the tests.
import pyopencl as cl import pyopencl as cl
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment