From 66d0d25cc43773d08cbb42f30f33e267172f4627 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 17 Jan 2021 22:11:57 -0600
Subject: [PATCH] Make a kernel-specific class to override __call__ and avoid
 an indirect call on kernel enqueue

---
 pyopencl/__init__.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index 6e0268dc..8a21d804 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -847,12 +847,22 @@ def _add_functionality():
         # }}}
 
         from pyopencl.invoker import generate_enqueue_and_set_args
-        self._enqueue, self._set_args = generate_enqueue_and_set_args(
-                self.function_name,
-                len(scalar_arg_dtypes), self.num_args,
-                self._scalar_arg_dtypes,
-                warn_about_arg_count_bug=warn_about_arg_count_bug,
-                work_around_arg_count_bug=work_around_arg_count_bug)
+        enqueue, set_args = \
+                generate_enqueue_and_set_args(
+                        self.function_name,
+                        len(scalar_arg_dtypes), self.num_args,
+                        self._scalar_arg_dtypes,
+                        warn_about_arg_count_bug=warn_about_arg_count_bug,
+                        work_around_arg_count_bug=work_around_arg_count_bug)
+
+        # Make ourselves a kernel-specific class, so that we're able to override
+        # __call__. Inspired by https://stackoverflow.com/a/38541437
+        class KernelWithOverriddenCall(type(self)):
+            pass
+
+        self.__class__ = KernelWithOverriddenCall
+        KernelWithOverriddenCall.__call__ = enqueue
+        KernelWithOverriddenCall._set_args = set_args
 
     def kernel_get_work_group_info(self, param, device):
         try:
@@ -871,6 +881,9 @@ def _add_functionality():
     def kernel_call(self, queue, global_size, local_size, *args, **kwargs):
         # __call__ can't be overridden directly, so we need this
         # trampoline hack.
+
+        # Note: This trampoline only serves the generic __call__, i.e. until
+        # kernel_set_scalar_arg_dtypes swaps in a class with its own __call__.
         return self._enqueue(self, queue, global_size, local_size, *args, **kwargs)
 
     def kernel_capture_call(self, filename, queue, global_size, local_size,
-- 
GitLab