From 99468f93d1cd1b73a39e31d6e52fea8f6c98cde9 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Fri, 21 Oct 2022 16:46:21 +0300
Subject: [PATCH] clean up docs for GenericScanKernel

---
 doc/algorithm.rst | 21 +--------------------
 pyopencl/scan.py  | 43 +++++++++++++++++++++++++++++++++++++++----
 2 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/doc/algorithm.rst b/doc/algorithm.rst
index 4bf5aefb..db2a1838 100644
--- a/doc/algorithm.rst
+++ b/doc/algorithm.rst
@@ -129,29 +129,10 @@ Making Custom Scan Kernels
 
 .. autoclass:: GenericScanKernel
 
-    .. method:: __call__(*args, allocator=None, queue=None, size=None, wait_for=None)
-
-        *queue* and *allocator* default to the ones provided on the first
-        :class:`pyopencl.array.Array` in *args*. *size* may specify the
-        length of the scan to be carried out. If not given, this length
-        is inferred from the first array argument passed.
-
-        |std-enqueue-blurb|
-
-        .. note::
-
-            The returned :class:`pyopencl.Event` corresponds only to part of the
-            execution of the scan. It is not suitable for profiling.
-
 Debugging aids
 ~~~~~~~~~~~~~~
 
-.. class:: GenericDebugScanKernel
-
-    Performs the same function and has the same interface as
-    :class:`GenericScanKernel`, but uses a dead-simple, sequential scan.  Works
-    best on CPU platforms, and helps isolate bugs in scans by removing the
-    potential for issues originating in parallel execution.
+.. autoclass:: GenericDebugScanKernel
 
 .. _predefined-scans:
 
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index f2c42bc2..f330f042 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -923,7 +923,7 @@ class ScanPerformanceWarning(UserWarning):
     pass
 
 
-class _GenericScanKernelBase(ABC):
+class GenericScanKernelBase(ABC):
     # {{{ constructor, argument processing
 
     def __init__(
@@ -937,7 +937,7 @@ class _GenericScanKernelBase(ABC):
             output_statement: str,
             is_segment_start_expr: Optional[str] = None,
             input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
-            index_dtype: Any = np.int32,
+            index_dtype: Any = None,
             name_prefix: str = "scan",
             options: Any = None,
             preamble: str = "",
@@ -1024,6 +1024,9 @@ class _GenericScanKernelBase(ABC):
         being processed in the scan.
         """
 
+        if index_dtype is None:
+            index_dtype = np.dtype(np.int32)
+
         if input_fetch_exprs is None:
             input_fetch_exprs = []
 
@@ -1142,7 +1145,7 @@ generic_scan_kernel_cache = WriteOncePersistentDict(
         key_builder=_NumpyTypesKeyBuilder())
 
 
-class GenericScanKernel(_GenericScanKernelBase):
+class GenericScanKernel(GenericScanKernelBase):
     """Generates and executes code that performs prefix sums ("scans") on
     arbitrary types, with many possible tweaks.
 
@@ -1158,6 +1161,9 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         a = cl.array.arange(queue, 10000, dtype=np.int32)
         knl(a, queue=queue)
+
+    .. automethod:: __init__
+    .. automethod:: __call__
     """
 
     def finish_setup(self) -> None:
@@ -1465,6 +1471,24 @@ class GenericScanKernel(_GenericScanKernelBase):
     # }}}
 
     def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
+        """
+        |std-enqueue-blurb|
+
+        .. note::
+
+            The returned :class:`pyopencl.Event` corresponds only to part of the
+            execution of the scan. It is not suitable for profiling.
+
+        :arg queue: queue on which to execute the scan. If not given, the
+            queue of the first :class:`pyopencl.array.Array` in *args* is used.
+        :arg allocator: an allocator for the temporary arrays and results. If
+            not given, the allocator of the first :class:`pyopencl.array.Array`
+            in *args* is used.
+        :arg size: specify the length of the scan to be carried out. If not
+            given, this length is inferred from the first array argument.
+        :arg wait_for: a :class:`list` of events to wait for.
+        """
+
         # {{{ argument processing
 
         allocator = kwargs.get("allocator")
@@ -1657,7 +1681,16 @@ void ${name_prefix}_debug_scan(
 """
 
 
-class GenericDebugScanKernel(_GenericScanKernelBase):
+class GenericDebugScanKernel(GenericScanKernelBase):
+    """
+    Performs the same function and has the same interface as
+    :class:`GenericScanKernel`, but uses a dead-simple, sequential scan.  Works
+    best on CPU platforms, and helps isolate bugs in scans by removing the
+    potential for issues originating in parallel execution.
+
+    .. automethod:: __call__
+    """
+
     def finish_setup(self) -> None:
         scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
         scan_src = str(scan_tpl.render(
@@ -1680,6 +1713,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
         self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
 
     def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
+        """See :meth:`GenericScanKernel.__call__`."""
+
         # {{{ argument processing
 
         allocator = kwargs.get("allocator")
-- 
GitLab