From 66d0d25cc43773d08cbb42f30f33e267172f4627 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 17 Jan 2021 22:11:57 -0600
Subject: [PATCH] Make a kernel-specific class to override __call__ and avoid an indirect call on kernel enqueue

---
 pyopencl/__init__.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index 6e0268dc..8a21d804 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -847,12 +847,22 @@ def _add_functionality():
         # }}}
 
         from pyopencl.invoker import generate_enqueue_and_set_args
-        self._enqueue, self._set_args = generate_enqueue_and_set_args(
-                self.function_name,
-                len(scalar_arg_dtypes), self.num_args,
-                self._scalar_arg_dtypes,
-                warn_about_arg_count_bug=warn_about_arg_count_bug,
-                work_around_arg_count_bug=work_around_arg_count_bug)
+        enqueue, set_args = \
+                generate_enqueue_and_set_args(
+                        self.function_name,
+                        len(scalar_arg_dtypes), self.num_args,
+                        self._scalar_arg_dtypes,
+                        warn_about_arg_count_bug=warn_about_arg_count_bug,
+                        work_around_arg_count_bug=work_around_arg_count_bug)
+
+        # Make ourselves a kernel-specific class, so that we're able to override
+        # __call__. Inspired by https://stackoverflow.com/a/38541437
+        class KernelWithOverriddenCall(type(self)):
+            pass
+
+        self.__class__ = KernelWithOverriddenCall
+        KernelWithOverriddenCall.__call__ = enqueue
+        KernelWithOverriddenCall._set_args = set_args
 
     def kernel_get_work_group_info(self, param, device):
         try:
@@ -871,6 +881,9 @@ def _add_functionality():
     def kernel_call(self, queue, global_size, local_size, *args, **kwargs):
         # __call__ can't be overridden directly, so we need this
         # trampoline hack.
+
+        # Note: This is only used for the generic __call__, before
+        # kernel_set_scalar_arg_dtypes is called.
         return self._enqueue(self, queue, global_size, local_size, *args, **kwargs)
 
     def kernel_capture_call(self, filename, queue, global_size, local_size,
-- 
GitLab