From 75e460b7405aa59560c0e9f2bb2497894cd7a584 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 17 Jan 2021 22:56:45 -0600
Subject: [PATCH] Streamline invoker code (and its generation)

---
 pyopencl/invoker.py | 59 +++++++++++----------------------------------
 1 file changed, 14 insertions(+), 45 deletions(-)

diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py
index f4e97615..6125628b 100644
--- a/pyopencl/invoker.py
+++ b/pyopencl/invoker.py
@@ -28,6 +28,7 @@ import numpy as np
 from warnings import warn
 import pyopencl._cl as _cl
 from pytools.persistent_dict import WriteOncePersistentDict
+from pytools.py_codegen import Indentation
 from pyopencl.tools import _NumpyTypesKeyBuilder
 
 _PYPY = "__pypy__" in sys.builtin_module_names
@@ -53,8 +54,6 @@ del _size_t_char
 # {{{ individual arg handling
 
 def generate_buffer_arg_setter(gen, arg_idx, buf_var):
-    from pytools.py_codegen import Indentation
-
     if _PYPY:
         # https://github.com/numpy/numpy/issues/5381
         gen(f"if isinstance({buf_var}, np.generic):")
@@ -69,29 +68,6 @@ def generate_buffer_arg_setter(gen, arg_idx, buf_var):
         """
         .format(arg_idx=arg_idx, buf_var=buf_var))
 
-
-def generate_bytes_arg_setter(gen, arg_idx, buf_var):
-    gen("""
-        self._set_arg_buf({arg_idx}, {buf_var})
-        """
-        .format(arg_idx=arg_idx, buf_var=buf_var))
-
-
-def generate_generic_arg_handler(gen, arg_idx, arg_var):
-    from pytools.py_codegen import Indentation
-
-    gen("""
-        if isinstance({arg_var}, _KERNEL_ARG_CLASSES):
-            self.set_arg({arg_idx}, {arg_var})
-        elif {arg_var} is None:
-            self._set_arg_null({arg_idx})
-        """
-        .format(arg_idx=arg_idx, arg_var=arg_var))
-
-    gen("else:")
-    with Indentation(gen):
-        generate_buffer_arg_setter(gen, arg_idx, arg_var)
-
 # }}}
 
 
@@ -108,7 +84,7 @@ def generate_generic_arg_handling_body(num_args):
         gen(f"# process argument {i}")
         gen("")
         gen(f"current_arg = {i}")
-        generate_generic_arg_handler(gen, i, "arg%d" % i)
+        gen(f"self.set_arg({i}, arg{i})")
         gen("")
 
     return gen
@@ -141,7 +117,7 @@ def generate_specific_arg_handling_body(function_name,
         arg_var = "arg%d" % arg_idx
 
         if arg_dtype is None:
-            generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
+            gen(f"self.set_arg({cl_arg_idx}, {arg_var})")
             cl_arg_idx += 1
             gen("")
             continue
@@ -149,7 +125,7 @@ def generate_specific_arg_handling_body(function_name,
         arg_dtype = np.dtype(arg_dtype)
 
         if arg_dtype.char == "V":
-            generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
+            gen(f"self.set_arg({cl_arg_idx}, {arg_var})")
             cl_arg_idx += 1
 
         elif arg_dtype.kind == "c":
@@ -173,13 +149,13 @@ def generate_specific_arg_handling_body(function_name,
                 gen(
                         "buf = pack('{arg_char}', {arg_var}.real)"
                         .format(arg_char=arg_char, arg_var=arg_var))
-                generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
+                gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
                 cl_arg_idx += 1
                 gen("current_arg = current_arg + 1000")
                 gen(
                         "buf = pack('{arg_char}', {arg_var}.imag)"
                         .format(arg_char=arg_char, arg_var=arg_var))
-                generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
+                gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
                 cl_arg_idx += 1
 
             elif (work_around_arg_count_bug == "apple"
@@ -195,7 +171,7 @@ def generate_specific_arg_handling_body(function_name,
                         "buf = pack('{arg_char}{arg_char}', "
                         "{arg_var}.real, {arg_var}.imag)"
                         .format(arg_char=arg_char, arg_var=arg_var))
-                generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
+                gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
                 cl_arg_idx += 1
 
             fp_arg_count += 2
@@ -211,7 +187,7 @@ def generate_specific_arg_handling_body(function_name,
                     .format(
                         arg_char=arg_char,
                         arg_var=arg_var))
-            generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
+            gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
             cl_arg_idx += 1
 
         gen("")
@@ -268,13 +244,6 @@ def wrap_in_error_handler(body, arg_names):
 # }}}
 
 
-def add_local_imports(gen):
-    gen("import numpy as np")
-    gen("import pyopencl._cl as _cl")
-    gen("from pyopencl import _KERNEL_ARG_CLASSES")
-    gen("")
-
-
 def _generate_enqueue_and_set_args_module(function_name,
         num_passed_args, num_cl_args,
         scalar_arg_dtypes,
@@ -292,12 +261,14 @@ def _generate_enqueue_and_set_args_module(function_name,
                 warn_about_arg_count_bug=warn_about_arg_count_bug,
                 work_around_arg_count_bug=work_around_arg_count_bug)
 
-    err_handler = wrap_in_error_handler(body, arg_names)
+    body = wrap_in_error_handler(body, arg_names)
 
     gen = PythonCodeGenerator()
 
     gen("from struct import pack")
     gen("from pyopencl import status_code")
+    gen("import numpy as np")
+    gen("import pyopencl._cl as _cl")
     gen("")
 
     # {{{ generate _enqueue
@@ -314,8 +285,7 @@ def _generate_enqueue_and_set_args_module(function_name,
                         "wait_for=None"])))
 
     with Indentation(gen):
-        add_local_imports(gen)
-        gen.extend(err_handler)
+        gen.extend(body)
 
         gen("""
             return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
@@ -332,8 +302,7 @@ def _generate_enqueue_and_set_args_module(function_name,
             % (", ".join(["self"] + arg_names)))
 
     with Indentation(gen):
-        add_local_imports(gen)
-        gen.extend(err_handler)
+        gen.extend(body)
 
     # }}}
 
@@ -341,7 +310,7 @@ def _generate_enqueue_and_set_args_module(function_name,
 
 
 invoker_cache = WriteOncePersistentDict(
-        "pyopencl-invoker-cache-v11",
+        "pyopencl-invoker-cache-v17",
         key_builder=_NumpyTypesKeyBuilder())
 
 
-- 
GitLab