From 6a25ac0fd99c706e90814ebc62fe70cd98e99177 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 18 Aug 2022 16:56:45 -0500
Subject: [PATCH] Add, use enqueue_fill

---
 doc/runtime_memory.rst |  2 ++
 pyopencl/__init__.py   | 22 ++++++++++++++++++++++
 pyopencl/array.py      |  4 ++--
 3 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/doc/runtime_memory.rst b/doc/runtime_memory.rst
index cfe41dc5..f9c31a8b 100644
--- a/doc/runtime_memory.rst
+++ b/doc/runtime_memory.rst
@@ -371,6 +371,8 @@ Transfers
 
 .. autofunction:: enqueue_copy(queue, dest, src, **kwargs)
 
+.. autofunction:: enqueue_fill(queue, dest, pattern, size, *, offset=0, wait_for=None)
+
 Mapping Memory into Host Address Space
 --------------------------------------
 
diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index ab042c0f..94c35cc5 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -1962,6 +1962,28 @@ def enqueue_copy(queue, dest, src, **kwargs):
 # }}}
 
 
+# {{{ enqueue_fill
+
+def enqueue_fill(queue: CommandQueue,
+        dest: "Union[MemoryObjectHolder, SVMPointer]",
+        pattern: Any, size: int, *, offset: int = 0,
+        wait_for: Optional[Sequence[Event]] = None) -> Event:
+    """Fill *dest* with *pattern*, dispatching on the type of *dest*.
+
+    .. versionadded:: 2022.2"""
+    if isinstance(dest, MemoryObjectHolder):  # forwards offset and size
+        return enqueue_fill_buffer(queue, dest, pattern, offset, size, wait_for)
+    elif isinstance(dest, SVMPointer):  # SVM path: nonzero offset is rejected
+        if offset:
+            raise NotImplementedError("enqueue_fill with SVM does not yet support "
+                    "offsets")
+        return enqueue_svm_memfill(queue, dest, pattern, size, wait_for)
+    else:
+        raise TypeError(f"enqueue_fill does not know how to fill '{type(dest)}'")
+
+# }}}
+
+
 # {{{ image creation
 
 DTYPE_TO_CHANNEL_TYPE = {
diff --git a/pyopencl/array.py b/pyopencl/array.py
index 80b1c61d..a1bf58e9 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1469,8 +1469,8 @@ class Array:
         # https://github.com/inducer/pyopencl/issues/395
         if cl_version_gtr_1_2 and not (on_nvidia and self.nbytes >= 2**31):
             self.add_event(
-                    cl.enqueue_fill_buffer(queue, self.base_data, np.int8(0),
-                        self.offset, self.nbytes, wait_for=wait_for))
+                    cl.enqueue_fill(queue, self.base_data, np.int8(0),
+                        self.nbytes, offset=self.offset, wait_for=wait_for))
         else:
             zero = np.zeros((), self.dtype)
             self.fill(zero, queue=queue)
-- 
GitLab