From bb34f9058d453d0507726ed4e3c23dcb1be6be4c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 9 Aug 2018 19:46:30 -0500
Subject: [PATCH] [pybind11] basic demo executes again

---
 pyopencl/__init__.py   | 306 +++++------------------------------------
 pyopencl/invoker.py    |  30 ++--
 src/wrap_cl_part_1.cpp |   1 +
 src/wrap_cl_part_2.cpp |   2 +
 4 files changed, 50 insertions(+), 289 deletions(-)

diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index f8ba8ccb..10c33c73 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -195,10 +195,18 @@ from pyopencl._cl import (  # noqa
 
 import inspect as _inspect
 
-CONSTANT_CLASSES = [
+CONSTANT_CLASSES = tuple(
         getattr(_cl, name) for name in dir(_cl)
         if _inspect.isclass(getattr(_cl, name))
-        and name[0].islower() and name not in ["zip", "map", "range"]]
+        and name[0].islower() and name not in ["zip", "map", "range"])
+
+_KERNEL_ARG_CLASSES = (
+        MemoryObjectHolder,
+        Sampler,
+        LocalMemory,
+        # FIXME
+        # SVM,
+        )
 
 
 if _cl.have_gl():
@@ -806,161 +814,18 @@ def _add_functionality():
     def kernel__setup(self, prg):
         self._source = getattr(prg, "_source", None)
 
-        self._generate_naive_call()
+        from pyopencl.invoker import generate_enqueue_and_set_args
+        self._enqueue, self._set_args = generate_enqueue_and_set_args(
+                self.function_name, self.num_args, self.num_args,
+                None,
+                warn_about_arg_count_bug=None,
+                work_around_arg_count_bug=None)
+
         self._wg_info_cache = {}
         return self
 
-    def kernel_get_work_group_info(self, param, device):
-        try:
-            return self._wg_info_cache[param, device]
-        except KeyError:
-            pass
-
-        result = kernel_old_get_work_group_info(self, param, device)
-        self._wg_info_cache[param, device] = result
-        return result
-
-    # {{{ code generation for __call__, set_args
-
-    def kernel__set_set_args_body(self, body, num_passed_args):
-        from pytools.py_codegen import (
-                PythonFunctionGenerator,
-                PythonCodeGenerator,
-                Indentation)
-
-        arg_names = ["arg%d" % i for i in range(num_passed_args)]
-
-        # {{{ wrap in error handler
-
-        err_gen = PythonCodeGenerator()
-
-        def gen_error_handler():
-            err_gen("""
-                if current_arg is not None:
-                    args = [{args}]
-                    advice = ""
-                    from pyopencl.array import Array
-                    if isinstance(args[current_arg], Array):
-                        advice = " (perhaps you meant to pass 'array.data' " \
-                            "instead of the array itself?)"
-
-                    raise _cl.LogicError(
-                            "when processing argument #%d (1-based): %s%s"
-                            % (current_arg+1, str(e), advice))
-                else:
-                    raise
-                """
-                .format(args=", ".join(arg_names)))
-            err_gen("")
-
-        err_gen("try:")
-        with Indentation(err_gen):
-            err_gen.extend(body)
-        err_gen("except TypeError as e:")
-        with Indentation(err_gen):
-            gen_error_handler()
-        err_gen("except _cl.LogicError as e:")
-        with Indentation(err_gen):
-            gen_error_handler()
-
-        # }}}
-
-        def add_preamble(gen):
-            gen.add_to_preamble(
-                "import numpy as np")
-            gen.add_to_preamble(
-                "import pyopencl._cl as _cl")
-            gen.add_to_preamble("from pyopencl import status_code")
-            gen.add_to_preamble("from struct import pack")
-            gen.add_to_preamble("")
-
-        # {{{ generate _enqueue
-
-        gen = PythonFunctionGenerator("enqueue_knl_%s" % self.function_name,
-                ["self", "queue", "global_size", "local_size"]
-                + arg_names
-                + ["global_offset=None", "g_times_l=None", "wait_for=None"])
-
-        add_preamble(gen)
-        gen.extend(err_gen)
-
-        gen("""
-            return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
-                    global_offset, wait_for, g_times_l=g_times_l)
-            """)
-
-        self._enqueue = gen.get_function()
-
-        # }}}
-
-        # {{{ generate set_args
-
-        gen = PythonFunctionGenerator("_set_args", ["self"] + arg_names)
-
-        add_preamble(gen)
-        gen.extend(err_gen)
-
-        self._set_args = gen.get_function()
-
-        # }}}
-
-    def kernel__generate_buffer_arg_setter(self, gen, arg_idx, buf_var):
-        # (TODO: still needed?)
-
-        # from pytools.py_codegen import Indentation
-        #
-        # if _CPY2:
-        #     # https://github.com/numpy/numpy/issues/5381
-        #     gen("if isinstance({buf_var}, np.generic):".format(buf_var=buf_var))
-        #     with Indentation(gen):
-        #         gen("{buf_var} = np.getbuffer({buf_var})".format(buf_var=buf_var))
-
-        gen("""
-            kernel._set_arg_bytes({arg_idx}, {buf_var})
-            """
-            .format(arg_idx=arg_idx, buf_var=buf_var))
-
-    def kernel__generate_bytes_arg_setter(self, gen, arg_idx, buf_var):
-        gen("""
-            self._set_arg_bytes({arg_idx}, {buf_var})
-            """
-            .format(arg_idx=arg_idx, buf_var=buf_var))
-
-    def kernel__generate_generic_arg_handler(self, gen, arg_idx, arg_var):
-        from pytools.py_codegen import Indentation
-
-        gen("""
-            if {arg_var} is None:
-                self._set_arg_null({arg_idx})
-            elif isinstance({arg_var}, _CLKernelArg):
-                self.set_arg({arg_idx}, {arg_var})
-            """
-            .format(arg_idx=arg_idx, arg_var=arg_var))
-
-        gen("else:")
-        with Indentation(gen):
-            self._generate_buffer_arg_setter(gen, arg_idx, arg_var)
-
-    def kernel__generate_naive_call(self):
-        num_args = self.num_args
-
-        from pytools.py_codegen import PythonCodeGenerator
-        gen = PythonCodeGenerator()
-
-        if num_args == 0:
-            gen("pass")
-
-        for i in range(num_args):
-            gen("# process argument {arg_idx}".format(arg_idx=i))
-            gen("")
-            gen("current_arg = {arg_idx}".format(arg_idx=i))
-            self._generate_generic_arg_handler(gen, i, "arg%d" % i)
-            gen("")
-
-        self._set_set_args_body(gen, num_args)
-
     def kernel_set_scalar_arg_dtypes(self, scalar_arg_dtypes):
-        self._scalar_arg_dtypes = scalar_arg_dtypes
+        self._scalar_arg_dtypes = tuple(scalar_arg_dtypes)
 
         # {{{ arg counting bug handling
 
@@ -974,7 +839,7 @@ def _add_functionality():
         from pyopencl.characterize import has_struct_arg_count_bug
 
         count_bug_per_dev = [
-                has_struct_arg_count_bug(dev)
+                has_struct_arg_count_bug(dev, self.context)
                 for dev in self.context.devices]
 
         from pytools import single_valued
@@ -984,119 +849,25 @@ def _add_functionality():
             else:
                 warn_about_arg_count_bug = True
 
-        fp_arg_count = 0
-
         # }}}
 
-        cl_arg_idx = 0
-
-        from pytools.py_codegen import PythonCodeGenerator
-        gen = PythonCodeGenerator()
-
-        if not scalar_arg_dtypes:
-            gen("pass")
-
-        for arg_idx, arg_dtype in enumerate(scalar_arg_dtypes):
-            gen("# process argument {arg_idx}".format(arg_idx=arg_idx))
-            gen("")
-            gen("current_arg = {arg_idx}".format(arg_idx=arg_idx))
-            arg_var = "arg%d" % arg_idx
-
-            if arg_dtype is None:
-                self._generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
-                cl_arg_idx += 1
-                gen("")
-                continue
-
-            arg_dtype = np.dtype(arg_dtype)
-
-            if arg_dtype.char == "V":
-                self._generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
-                cl_arg_idx += 1
-
-            elif arg_dtype.kind == "c":
-                if warn_about_arg_count_bug:
-                    warn("{knl_name}: arguments include complex numbers, and "
-                            "some (but not all) of the target devices mishandle "
-                            "struct kernel arguments (hence the workaround is "
-                            "disabled".format(
-                                knl_name=self.function_name, stacklevel=2))
-
-                if arg_dtype == np.complex64:
-                    arg_char = "f"
-                elif arg_dtype == np.complex128:
-                    arg_char = "d"
-                else:
-                    raise TypeError("unexpected complex type: %s" % arg_dtype)
-
-                if (work_around_arg_count_bug == "pocl"
-                        and arg_dtype == np.complex128
-                        and fp_arg_count + 2 <= 8):
-                    gen(
-                            "buf = pack('{arg_char}', {arg_var}.real)"
-                            .format(arg_char=arg_char, arg_var=arg_var))
-                    self._generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
-                    cl_arg_idx += 1
-                    gen("current_arg = current_arg + 1000")
-                    gen(
-                            "buf = pack('{arg_char}', {arg_var}.imag)"
-                            .format(arg_char=arg_char, arg_var=arg_var))
-                    self._generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
-                    cl_arg_idx += 1
-
-                elif (work_around_arg_count_bug == "apple"
-                        and arg_dtype == np.complex128
-                        and fp_arg_count + 2 <= 8):
-                    raise NotImplementedError("No work-around to "
-                            "Apple's broken structs-as-kernel arg "
-                            "handling has been found. "
-                            "Cannot pass complex numbers to kernels.")
-
-                else:
-                    gen(
-                            "buf = pack('{arg_char}{arg_char}', "
-                            "{arg_var}.real, {arg_var}.imag)"
-                            .format(arg_char=arg_char, arg_var=arg_var))
-                    self._generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
-                    cl_arg_idx += 1
-
-                fp_arg_count += 2
-
-            elif arg_dtype.char in "IL" and _CPY26:
-                # Prevent SystemError: ../Objects/longobject.c:336: bad
-                # argument to internal function
-
-                gen(
-                        "buf = pack('{arg_char}', long({arg_var}))"
-                        .format(arg_char=arg_dtype.char, arg_var=arg_var))
-                self._generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
-                cl_arg_idx += 1
-
-            else:
-                if arg_dtype.kind == "f":
-                    fp_arg_count += 1
-
-                arg_char = arg_dtype.char
-                arg_char = _type_char_map.get(arg_char, arg_char)
-                gen(
-                        "buf = pack('{arg_char}', {arg_var})"
-                        .format(
-                            arg_char=arg_char,
-                            arg_var=arg_var))
-                self._generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
-                cl_arg_idx += 1
-
-            gen("")
-
-        if cl_arg_idx != self.num_args:
-            raise TypeError(
-                "length of argument list (%d) and "
-                "CL-generated number of arguments (%d) do not agree"
-                % (cl_arg_idx, self.num_args))
+        from pyopencl.invoker import generate_enqueue_and_set_args
+        self._enqueue, self._set_args = generate_enqueue_and_set_args(
+                self.function_name,
+                len(scalar_arg_dtypes), self.num_args,
+                self._scalar_arg_dtypes,
+                warn_about_arg_count_bug=warn_about_arg_count_bug,
+                work_around_arg_count_bug=work_around_arg_count_bug)
 
-        self._set_set_args_body(gen, len(scalar_arg_dtypes))
+    def kernel_get_work_group_info(self, param, device):
+        try:
+            return self._wg_info_cache[param, device]
+        except KeyError:
+            pass
 
-    # }}}
+        result = kernel_old_get_work_group_info(self, param, device)
+        self._wg_info_cache[param, device] = result
+        return result
 
     def kernel_set_args(self, *args, **kwargs):
         # Need to dupicate the 'self' argument for dynamically generated  method
@@ -1116,11 +887,6 @@ def _add_functionality():
     Kernel.__init__ = kernel_init
     Kernel._setup = kernel__setup
     Kernel.get_work_group_info = kernel_get_work_group_info
-    Kernel._set_set_args_body = kernel__set_set_args_body
-    Kernel._generate_bufprot_arg_setter = kernel__generate_bufprot_arg_setter
-    Kernel._generate_bytes_arg_setter = kernel__generate_bytes_arg_setter
-    Kernel._generate_generic_arg_handler = kernel__generate_generic_arg_handler
-    Kernel._generate_naive_call = kernel__generate_naive_call
     Kernel.set_scalar_arg_dtypes = kernel_set_scalar_arg_dtypes
     Kernel.set_args = kernel_set_args
     Kernel.__call__ = kernel_call
@@ -1673,7 +1439,9 @@ def enqueue_copy(queue, dest, src, **kwargs):
         else:
             raise ValueError("invalid dest mem object type")
 
-    elif isinstance(dest, SVM):
+    # FIXME
+    # elif isinstance(dest, SVM):
+    elif 0:
         # to SVM
         if isinstance(src, SVM):
             src = src.mem
diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py
index 8cad3f25..7fad942c 100644
--- a/pyopencl/invoker.py
+++ b/pyopencl/invoker.py
@@ -28,7 +28,7 @@ import sys
 import numpy as np
 
 from warnings import warn
-from pyopencl._cffi import ffi as _ffi
+import pyopencl._cl as _cl
 from pytools.persistent_dict import WriteOncePersistentDict
 from pyopencl.tools import _NumpyTypesKeyBuilder
 
@@ -44,7 +44,7 @@ _size_t_char = ({
     4: 'L',
     2: 'H',
     1: 'B',
-})[_ffi.sizeof('size_t')]
+})[_cl._sizeof_size_t()]
 _type_char_map = {
     'n': _size_t_char.lower(),
     'N': _size_t_char
@@ -66,20 +66,14 @@ def generate_buffer_arg_setter(gen, arg_idx, buf_var):
             gen("{buf_var} = np.getbuffer({buf_var})".format(buf_var=buf_var))
 
     gen("""
-        c_buf, sz, _ = _cl._c_buffer_from_obj({buf_var})
-        status = _lib.kernel__set_arg_buf(self.ptr, {arg_idx}, c_buf, sz)
-        if status != _ffi.NULL:
-            _handle_error(status)
+        self._set_arg_buf({arg_idx}, {buf_var})
         """
         .format(arg_idx=arg_idx, buf_var=buf_var))
 
 
 def generate_bytes_arg_setter(gen, arg_idx, buf_var):
     gen("""
-        status = _lib.kernel__set_arg_buf(self.ptr, {arg_idx},
-            {buf_var}, len({buf_var}))
-        if status != _ffi.NULL:
-            _handle_error(status)
+        self._set_arg_buf({arg_idx}, {buf_var})
         """
         .format(arg_idx=arg_idx, buf_var=buf_var))
 
@@ -89,11 +83,9 @@ def generate_generic_arg_handler(gen, arg_idx, arg_var):
 
     gen("""
         if {arg_var} is None:
-            status = _lib.kernel__set_arg_null(self.ptr, {arg_idx})
-            if status != _ffi.NULL:
-                _handle_error(status)
-        elif isinstance({arg_var}, _cl._CLKernelArg):
-            self._set_arg_clkernelarg({arg_idx}, {arg_var})
+            self._set_arg_null({arg_idx})
+        elif isinstance({arg_var}, _KERNEL_ARG_CLASSES):
+            self.set_arg({arg_idx}, {arg_var})
         """
         .format(arg_idx=arg_idx, arg_var=arg_var))
 
@@ -289,10 +281,8 @@ def wrap_in_error_handler(body, arg_names):
 
 def add_local_imports(gen):
     gen("import numpy as np")
-    gen("import pyopencl.cffi_cl as _cl")
-    gen(
-        "from pyopencl.cffi_cl import _lib, "
-        "_ffi, _handle_error, _CLKernelArg")
+    gen("import pyopencl._cl as _cl")
+    gen("from pyopencl import _KERNEL_ARG_CLASSES")
     gen("")
 
 
@@ -359,7 +349,7 @@ def _generate_enqueue_and_set_args_module(function_name,
 
 
 invoker_cache = WriteOncePersistentDict(
-        "pyopencl-invoker-cache-v1",
+        "pyopencl-invoker-cache-v4",
         key_builder=_NumpyTypesKeyBuilder())
 
 
diff --git a/src/wrap_cl_part_1.cpp b/src/wrap_cl_part_1.cpp
index 5c7bf70b..45633bd2 100644
--- a/src/wrap_cl_part_1.cpp
+++ b/src/wrap_cl_part_1.cpp
@@ -7,6 +7,7 @@ using namespace pyopencl;
 void pyopencl_expose_part_1(py::module &m)
 {
   m.def("get_cl_header_version", get_cl_header_version);
+  m.def("_sizeof_size_t", [](){ return sizeof(size_t); });
 
   // {{{ platform
   DEF_SIMPLE_FUNCTION(get_platforms);
diff --git a/src/wrap_cl_part_2.cpp b/src/wrap_cl_part_2.cpp
index f51e8a7b..13670472 100644
--- a/src/wrap_cl_part_2.cpp
+++ b/src/wrap_cl_part_2.cpp
@@ -320,6 +320,8 @@ void pyopencl_expose_part_2(py::module &m)
       .def(py::init<const program &, std::string const &>())
       .DEF_SIMPLE_METHOD(get_info)
       .DEF_SIMPLE_METHOD(get_work_group_info)
+      .def("_set_arg_null", &cls::set_arg_null)
+      .def("_set_arg_buf", &cls::set_arg_buf)
       .DEF_SIMPLE_METHOD(set_arg)
 #if PYOPENCL_CL_VERSION >= 0x1020
       .DEF_SIMPLE_METHOD(get_arg_info)
-- 
GitLab