From bd750fa0b0816f07e5547d914432b54f59ef5d9f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 2 Aug 2012 01:39:06 -0400
Subject: [PATCH] Add support for ranged elementwise kernels.

---
 pyopencl/elementwise.py | 183 +++++++++++++++++++++++++++++-----------
 test/test_array.py      |  24 ++++++
 2 files changed, 158 insertions(+), 49 deletions(-)

diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index ed755d13..f3bdf3c7 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -31,77 +31,117 @@ OTHER DEALINGS IN THE SOFTWARE.
 from pyopencl.tools import context_dependent_memoize
 import numpy as np
 import pyopencl as cl
-from pyopencl.tools import dtype_to_ctype, VectorArg, ScalarArg
+from pytools import memoize_method
+from pyopencl.tools import (dtype_to_ctype, VectorArg, ScalarArg,
+        KernelTemplateBase)
 
 
+# {{{ elementwise kernel code generator
+
 def get_elwise_program(context, arguments, operation,
         name="elwise_kernel", options=[],
-        preamble="", loop_prep="", after_loop=""):
-    from pyopencl import Program
-    source = ("""
-        %(preamble)s
-
-        __kernel void %(name)s(%(arguments)s)
-        {
-          unsigned lid = get_local_id(0);
-          unsigned gsize = get_global_size(0);
-          unsigned work_item_start = get_local_size(0)*get_group_id(0);
-          unsigned i;
-
-          %(loop_prep)s;
+        preamble="", loop_prep="", after_loop="",
+        use_range=False):
 
+    if use_range:
+        body = r"""//CL//
+          if (step < 0)
+          {
+            for (i = start + (work_item_start + lid)*step;
+              i > stop; i += gsize*step)
+            {
+              %(operation)s;
+            }
+          }
+          else
+          {
+            for (i = start + (work_item_start + lid)*step;
+              i < stop; i += gsize*step)
+            {
+              %(operation)s;
+            }
+          }
+          """
+    else:
+        body = """//CL//
           for (i = work_item_start + lid; i < n; i += gsize)
           {
             %(operation)s;
           }
+          """
+
+    source = ("""//CL//
+        %(preamble)s
 
+        __kernel void %(name)s(%(arguments)s)
+        {
+          int lid = get_local_id(0);
+          int gsize = get_global_size(0);
+          int work_item_start = get_local_size(0)*get_group_id(0);
+          long i;
+
+          %(loop_prep)s;
+          %(body)s
           %(after_loop)s;
         }
         """ % {
             "arguments": ", ".join(arg.declarator() for arg in arguments),
-            "operation": operation,
             "name": name,
             "preamble": preamble,
             "loop_prep": loop_prep,
             "after_loop": after_loop,
+            "body": body %  dict(operation=operation),
             })
 
+    from pyopencl import Program
     return Program(context, source).build(options)
 
 
 def get_elwise_kernel_and_types(context, arguments, operation,
-        name="elwise_kernel", options=[], preamble="", **kwargs):
+        name="elwise_kernel", options=[], preamble="", use_range=False,
+        **kwargs):
     if isinstance(arguments, str):
         from pyopencl.tools import parse_c_arg
         parsed_args = [parse_c_arg(arg) for arg in arguments.split(",")]
     else:
         parsed_args = arguments
 
+    auto_preamble = kwargs.pop("auto_preamble", True)
+
     pragmas = []
     includes = []
     have_double_pragma = False
     have_complex_include = False
 
-    for arg in parsed_args:
-        if arg.dtype in [np.float64, np.complex128]:
-            if not have_double_pragma:
-                pragmas.append(
-                        "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n"
-                        "#define PYOPENCL_DEFINE_CDOUBLE\n")
-                have_double_pragma = True
-        if arg.dtype.kind == 'c':
-            if not have_complex_include:
-                includes.append("#include <pyopencl-complex.h>\n")
-                have_complex_include = True
+    if auto_preamble:
+        for arg in parsed_args:
+            if arg.dtype in [np.float64, np.complex128]:
+                if not have_double_pragma:
+                    pragmas.append(
+                            "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n"
+                            "#define PYOPENCL_DEFINE_CDOUBLE\n")
+                    have_double_pragma = True
+            if arg.dtype.kind == 'c':
+                if not have_complex_include:
+                    includes.append("#include <pyopencl-complex.h>\n")
+                    have_complex_include = True
 
     if pragmas or includes:
         preamble = "\n".join(pragmas+includes) + "\n" + preamble
 
-    parsed_args.append(ScalarArg(np.uintp, "n"))
+    if use_range:
+        parsed_args.extend([
+            ScalarArg(np.intp, "start"),
+            ScalarArg(np.intp, "stop"),
+            ScalarArg(np.intp, "step"),
+            ])
+    else:
+        parsed_args.append(ScalarArg(np.intp, "n"))
 
     prg = get_elwise_program(
         context, parsed_args, operation,
-        name=name, options=options, preamble=preamble, **kwargs)
+        name=name, options=options, preamble=preamble,
+        use_range=use_range, **kwargs)
 
     scalar_arg_dtypes = []
     for arg in parsed_args:
@@ -136,50 +176,95 @@ def get_elwise_kernel(context, arguments, operation,
 class ElementwiseKernel:
     def __init__(self, context, arguments, operation,
             name="elwise_kernel", options=[], **kwargs):
-
-        self.kernel, self.arguments = get_elwise_kernel_and_types(
-            context, arguments, operation,
-            name=name, options=options,
-            **kwargs)
-
-        if not [i for i, arg in enumerate(self.arguments)
+        self.context = context
+        self.arguments = arguments
+        self.operation = operation
+        self.name = name
+        self.options = options
+        self.kwargs = kwargs
+
+    @memoize_method
+    def get_kernel(self, use_range):
+        knl, arg_descrs = get_elwise_kernel_and_types(
+            self.context, self.arguments, self.operation,
+            name=self.name, options=self.options,
+            use_range=use_range, **self.kwargs)
+
+        if not [i for i, arg in enumerate(arg_descrs)
                 if isinstance(arg, VectorArg)]:
             raise RuntimeError(
                 "ElementwiseKernel can only be used with "
                 "functions that have at least one "
                 "vector argument")
+        return knl, arg_descrs
 
     def __call__(self, *args, **kwargs):
-        vectors = []
+        repr_vec = None
+
+        range_ = kwargs.pop("range", None)
+        slice_ = kwargs.pop("slice", None)
+
+        use_range = range_ is not None or slice_ is not None
+        kernel, arg_descrs = self.get_kernel(use_range)
+
+        # {{{ assemble arg array
 
         invocation_args = []
-        for arg, arg_descr in zip(args, self.arguments):
+        for arg, arg_descr in zip(args, arg_descrs):
             if isinstance(arg_descr, VectorArg):
                 if not arg.flags.forc:
                     raise RuntimeError("ElementwiseKernel cannot "
                             "deal with non-contiguous arrays")
 
-                vectors.append(arg)
+                if repr_vec is None:
+                    repr_vec = arg
+
                 invocation_args.append(arg.data)
             else:
                 invocation_args.append(arg)
 
+        # }}}
+
         queue = kwargs.pop("queue", None)
         wait_for = kwargs.pop("wait_for", None)
         if kwargs:
-            raise TypeError("too many/unknown keyword arguments")
+            raise TypeError("unknown keyword arguments: '%s'"
+                    % ", ".join(kwargs))
 
-        repr_vec = vectors[0]
         if queue is None:
             queue = repr_vec.queue
-        invocation_args.append(repr_vec.mem_size)
-
-        gs, ls = repr_vec.get_sizes(queue,
-                self.kernel.get_work_group_info(
-                    cl.kernel_work_group_info.WORK_GROUP_SIZE,
-                    queue.device))
-        self.kernel.set_args(*invocation_args)
-        return cl.enqueue_nd_range_kernel(queue, self.kernel,
+
+        if slice_ is not None:
+            if range_ is not None:
+                raise TypeError("may not specify both range and slice "
+                        "keyword arguments")
+
+            range_ = slice(*slice_.indices(repr_vec.size))
+
+        max_wg_size = kernel.get_work_group_info(
+                cl.kernel_work_group_info.WORK_GROUP_SIZE,
+                queue.device)
+
+        if range_ is not None:
+            invocation_args.append(range_.start)
+            invocation_args.append(range_.stop)
+            if range_.step is None:
+                step = 1
+            else:
+                step = range_.step
+
+            invocation_args.append(step)
+
+            from pyopencl.array import splay
+            gs, ls = splay(queue,
+                    abs(range_.stop - range_.start)//step,
+                    max_wg_size)
+        else:
+            invocation_args.append(repr_vec.mem_size)
+            gs, ls = repr_vec.get_sizes(queue, max_wg_size)
+
+        kernel.set_args(*invocation_args)
+        return cl.enqueue_nd_range_kernel(queue, kernel,
                 gs, ls, wait_for=wait_for)
 
 
diff --git a/test/test_array.py b/test/test_array.py
index cdc1cfef..3fc92bac 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -477,6 +477,30 @@ def test_elwise_kernel_with_options(ctx_factory):
     assert la.norm(gv - gt) < 1e-5
 
 
+@pytools.test.mark_test.opencl
+def test_ranged_elwise_kernel(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    from pyopencl.elementwise import ElementwiseKernel
+    set_to_seven = ElementwiseKernel(context,
+            "float *z", "z[i] = 7", "set_to_seven")
+
+    for i, slc in enumerate([
+            slice(5, 20000),
+            slice(5, 20000, 17),
+            slice(3000, 5, -1),
+            slice(1000, -1),
+            ]):
+
+        a_gpu = cl_array.zeros(queue, (50000,), dtype=np.float32)
+        a_cpu = np.zeros(a_gpu.shape, a_gpu.dtype)
+
+        a_cpu[slc] = 7
+        set_to_seven(a_gpu, slice=slc)
+
+        assert (a_cpu == a_gpu.get()).all()
+
 @pytools.test.mark_test.opencl
 def test_take(ctx_factory):
     context = ctx_factory()
-- 
GitLab