Zero-fill arrays with enqueue_fill_buffer when OpenCL 1.2 is available.

Review notes:
  * Array.fill() is intentionally not modified. Routing it through
    enqueue_fill_buffer(..., np.int8(0), ...) would ignore *value* and
    would pass a queue argument that may still be None.
  * Array._zero_fill() forwards wait_for to fill() on the fallback path.
  * Context-line whitespace was reconstructed from a collapsed copy of the
    patch; apply with --ignore-whitespace if it does not match exactly.
    The index blob ids are carried over from that copy.

diff --git a/pyopencl/array.py b/pyopencl/array.py
index d3439739fd3d1a5b9753d2f6a62669b3db4fb269..f6f8a35a9135f00d06e22e3c821e6a854f2b9b23 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1042,6 +1042,26 @@ class Array(object):
 
     __rtruediv__ = __rdiv__
 
+    def _zero_fill(self, queue=None, wait_for=None):
+        """Set every byte of this array's storage to zero.
+
+        Uses :func:`pyopencl.enqueue_fill_buffer` when both the queue's CL
+        version and the CL header version are at least 1.2; otherwise falls
+        back to :meth:`fill` with a zero scalar of the array's dtype.
+        """
+        queue = queue or self.queue
+
+        if (
+                queue._get_cl_version() >= (1, 2)
+                and cl.get_cl_header_version() >= (1, 2)):
+
+            self.add_event(
+                    cl.enqueue_fill_buffer(queue, self.base_data, np.int8(0),
+                        self.offset, self.nbytes, wait_for=wait_for))
+        else:
+            zero = np.zeros((), self.dtype)
+            self.fill(zero, queue=queue, wait_for=wait_for)
+
     def fill(self, value, queue=None, wait_for=None):
         """Fill the array with *scalar*.
 
@@ -1771,8 +1791,7 @@ def zeros(queue, shape, dtype, order="C", allocator=None):
 
     result = Array(queue, shape, dtype,
             order=order, allocator=allocator)
-    zero = np.zeros((), dtype)
-    result.fill(zero)
+    result._zero_fill()
     return result
 
 
@@ -1791,8 +1810,7 @@ def zeros_like(ary):
     """
 
     result = empty_like(ary)
-    zero = np.zeros((), ary.dtype)
-    result.fill(zero)
+    result._zero_fill()
     return result
 
 
diff --git a/test/test_array.py b/test/test_array.py
index 7b48a95410638d3ea0ab7d0fa9925de836add785..1876b081f4977d9935076ab49c6a69a9f345e65b 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -231,6 +231,36 @@ def test_absrealimag(ctx_factory):
                 print(dev_res-host_res)
             assert correct
 
+
+def test_custom_type_zeros(ctx_factory):
+    """Check that zeros() works for a registered struct dtype."""
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    if not (
+            queue._get_cl_version() >= (1, 2)
+            and cl.get_cl_header_version() >= (1, 2)):
+        pytest.skip("CL1.2 not available")
+
+    dtype = np.dtype([
+        ("cur_min", np.int32),
+        ("cur_max", np.int32),
+        ("pad", np.int32),
+        ])
+
+    from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct
+
+    name = "mmc_type"
+    dtype, c_decl = match_dtype_to_c_struct(queue.device, name, dtype)
+    dtype = get_or_register_dtype(name, dtype)
+
+    n = 1000
+    z_dev = cl.array.zeros(queue, n, dtype=dtype)
+
+    z = z_dev.get()
+
+    assert np.array_equal(np.zeros(n, dtype), z)
+
 # }}}
 
 