From b330721dd37c755c68ad7653f64a19fc81227335 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 17 Aug 2016 14:34:21 -0500
Subject: [PATCH] Make array.zeros work for custom types

zeros() and zeros_like() now call the new Array._zero_fill() instead of
fill() with a zero scalar. When OpenCL 1.2 is available, _zero_fill()
zeroes the buffer byte-wise via enqueue_fill_buffer, which does not
depend on the dtype. Otherwise it falls back to fill() with a zero
scalar of the array's dtype, forwarding wait_for.

Array.fill() itself keeps going through self._fill() so that it still
honors the *value* argument.

---
Note: revised patch. The commit id above and the blob ids on the "index"
lines are those of the original commit and were not regenerated.

 pyopencl/array.py  | 27 +++++++++++++++++++++++----
 test/test_array.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index d3439739..f6f8a35a 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1042,13 +1042,34 @@ class Array(object):
 
     __rtruediv__ = __rdiv__
 
+    def _zero_fill(self, queue=None, wait_for=None):
+        """Fill the array with zeros.
+
+        Zeroes the buffer byte-wise when OpenCL 1.2 is available, which does
+        not depend on the dtype and so also works for custom struct types;
+        otherwise falls back to :meth:`fill` with a zero scalar.
+        """
+        queue = queue or self.queue
+
+        if (
+                queue._get_cl_version() >= (1, 2)
+                and cl.get_cl_header_version() >= (1, 2)):
+
+            self.add_event(
+                    cl.enqueue_fill_buffer(queue, self.base_data, np.int8(0),
+                        self.offset, self.nbytes, wait_for=wait_for))
+        else:
+            zero = np.zeros((), self.dtype)
+            self.fill(zero, queue=queue, wait_for=wait_for)
+
     def fill(self, value, queue=None, wait_for=None):
         """Fill the array with *scalar*.
 
         :returns: *self*.
         """
+
         self.add_event(
                 self._fill(self, value, queue=queue, wait_for=wait_for))
 
         return self
 
@@ -1771,8 +1792,7 @@ def zeros(queue, shape, dtype, order="C", allocator=None):
 
     result = Array(queue, shape, dtype,
             order=order, allocator=allocator)
-    zero = np.zeros((), dtype)
-    result.fill(zero)
+    result._zero_fill()
     return result
 
 
@@ -1791,8 +1811,7 @@ def zeros_like(ary):
     """
 
     result = empty_like(ary)
-    zero = np.zeros((), ary.dtype)
-    result.fill(zero)
+    result._zero_fill()
     return result
 
 
diff --git a/test/test_array.py b/test/test_array.py
index 7b48a954..1876b081 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -231,6 +231,35 @@ def test_absrealimag(ctx_factory):
                 print(dev_res-host_res)
             assert correct
 
+
+def test_custom_type_zeros(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    if not (
+            queue._get_cl_version() >= (1, 2)
+            and cl.get_cl_header_version() >= (1, 2)):
+        pytest.skip("CL1.2 not available")
+
+    dtype = np.dtype([
+        ("cur_min", np.int32),
+        ("cur_max", np.int32),
+        ("pad", np.int32),
+        ])
+
+    from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct
+
+    name = "mmc_type"
+    dtype, c_decl = match_dtype_to_c_struct(queue.device, name, dtype)
+    dtype = get_or_register_dtype(name, dtype)
+
+    n = 1000
+    z_dev = cl.array.zeros(queue, n, dtype=dtype)
+
+    z = z_dev.get()
+
+    assert np.array_equal(np.zeros(n, dtype), z)
+
 # }}}
 
 
-- 
GitLab