From b330721dd37c755c68ad7653f64a19fc81227335 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 17 Aug 2016 14:34:21 -0500
Subject: [PATCH] Make array.zeros work for custom types

---
 pyopencl/array.py  | 24 +++++++++++++++++++-----
 test/test_array.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index d3439739..f6f8a35a 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1042,13 +1042,29 @@ class Array(object):
 
     __rtruediv__ = __rdiv__
 
+    def _zero_fill(self, queue=None, wait_for=None):
+        """Zero-fill the array, using a raw byte fill where CL 1.2 allows."""
+        queue = queue or self.queue
+        if (
+                queue._get_cl_version() >= (1, 2)
+                and cl.get_cl_header_version() >= (1, 2)):
+            # Byte-wise zero pattern over nbytes, independent of self.dtype.
+            self.add_event(
+                    cl.enqueue_fill_buffer(queue, self.base_data, np.int8(0),
+                        self.offset, self.nbytes, wait_for=wait_for))
+        else:
+            # Pre-1.2 fallback; forward wait_for so ordering is preserved.
+            self.fill(np.zeros((), self.dtype), queue=queue, wait_for=wait_for)
+
     def fill(self, value, queue=None, wait_for=None):
         """Fill the array with *scalar*.
 
         :returns: *self*.
         """
+
         self.add_event(
-                self._fill(self, value, queue=queue, wait_for=wait_for))
+                # Honor *value*; only _zero_fill may use the raw zero byte fill.
+                self._fill(self, value, queue=queue, wait_for=wait_for))
 
         return self
 
@@ -1771,8 +1787,7 @@ def zeros(queue, shape, dtype, order="C", allocator=None):
 
     result = Array(queue, shape, dtype,
             order=order, allocator=allocator)
-    zero = np.zeros((), dtype)
-    result.fill(zero)
+    result._zero_fill()
     return result
 
 
@@ -1791,8 +1806,7 @@ def zeros_like(ary):
     """
 
     result = empty_like(ary)
-    zero = np.zeros((), ary.dtype)
-    result.fill(zero)
+    result._zero_fill()
     return result
 
 
diff --git a/test/test_array.py b/test/test_array.py
index 7b48a954..1876b081 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -231,6 +231,35 @@ def test_absrealimag(ctx_factory):
                 print(dev_res-host_res)
             assert correct
 
+
+def test_custom_type_zeros(ctx_factory):
+    """Check that ``cl.array.zeros`` zero-fills an array of a struct dtype."""
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    # Same CL 1.2 check as Array._zero_fill uses to pick enqueue_fill_buffer.
+    if (
+            queue._get_cl_version() < (1, 2)
+            or cl.get_cl_header_version() < (1, 2)):
+        pytest.skip("CL1.2 not available")
+
+    from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct
+
+    type_name = "mmc_type"
+    fields = [
+        ("cur_min", np.int32),
+        ("cur_max", np.int32),
+        ("pad", np.int32),
+        ]
+    dtype, _ = match_dtype_to_c_struct(
+            queue.device, type_name, np.dtype(fields))
+    dtype = get_or_register_dtype(type_name, dtype)
+
+    n_entries = 1000
+    host_result = cl.array.zeros(queue, n_entries, dtype=dtype).get()
+
+    assert np.array_equal(np.zeros(n_entries, dtype), host_result)
+
 # }}}
 
 
-- 
GitLab