From 75e460b7405aa59560c0e9f2bb2497894cd7a584 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 17 Jan 2021 22:56:45 -0600 Subject: [PATCH] Streamline invoker code (and its generation) --- pyopencl/invoker.py | 59 +++++++++++---------------------------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py index f4e97615..6125628b 100644 --- a/pyopencl/invoker.py +++ b/pyopencl/invoker.py @@ -28,6 +28,7 @@ import numpy as np from warnings import warn import pyopencl._cl as _cl from pytools.persistent_dict import WriteOncePersistentDict +from pytools.py_codegen import Indentation from pyopencl.tools import _NumpyTypesKeyBuilder _PYPY = "__pypy__" in sys.builtin_module_names @@ -53,8 +54,6 @@ del _size_t_char # {{{ individual arg handling def generate_buffer_arg_setter(gen, arg_idx, buf_var): - from pytools.py_codegen import Indentation - if _PYPY: # https://github.com/numpy/numpy/issues/5381 gen(f"if isinstance({buf_var}, np.generic):") @@ -69,29 +68,6 @@ def generate_buffer_arg_setter(gen, arg_idx, buf_var): """ .format(arg_idx=arg_idx, buf_var=buf_var)) - -def generate_bytes_arg_setter(gen, arg_idx, buf_var): - gen(""" - self._set_arg_buf({arg_idx}, {buf_var}) - """ - .format(arg_idx=arg_idx, buf_var=buf_var)) - - -def generate_generic_arg_handler(gen, arg_idx, arg_var): - from pytools.py_codegen import Indentation - - gen(""" - if isinstance({arg_var}, _KERNEL_ARG_CLASSES): - self.set_arg({arg_idx}, {arg_var}) - elif {arg_var} is None: - self._set_arg_null({arg_idx}) - """ - .format(arg_idx=arg_idx, arg_var=arg_var)) - - gen("else:") - with Indentation(gen): - generate_buffer_arg_setter(gen, arg_idx, arg_var) - # }}} @@ -108,7 +84,7 @@ def generate_generic_arg_handling_body(num_args): gen(f"# process argument {i}") gen("") gen(f"current_arg = {i}") - generate_generic_arg_handler(gen, i, "arg%d" % i) + gen(f"self.set_arg({i}, arg{i})") gen("") return gen @@ -141,7 +117,7 @@ def generate_specific_arg_handling_body(function_name, arg_var = "arg%d" % arg_idx if arg_dtype is None: - generate_generic_arg_handler(gen, cl_arg_idx, arg_var) + gen(f"self.set_arg({cl_arg_idx}, {arg_var})") cl_arg_idx += 1 gen("") continue @@ -149,7 +125,7 @@ def generate_specific_arg_handling_body(function_name, arg_dtype = np.dtype(arg_dtype) if arg_dtype.char == "V": - generate_generic_arg_handler(gen, cl_arg_idx, arg_var) + gen(f"self.set_arg({cl_arg_idx}, {arg_var})") cl_arg_idx += 1 elif arg_dtype.kind == "c": @@ -173,13 +149,13 @@ def generate_specific_arg_handling_body(function_name, gen( "buf = pack('{arg_char}', {arg_var}.real)" .format(arg_char=arg_char, arg_var=arg_var)) - generate_bytes_arg_setter(gen, cl_arg_idx, "buf") + gen(f"self._set_arg_buf({cl_arg_idx}, buf)") cl_arg_idx += 1 gen("current_arg = current_arg + 1000") gen( "buf = pack('{arg_char}', {arg_var}.imag)" .format(arg_char=arg_char, arg_var=arg_var)) - generate_bytes_arg_setter(gen, cl_arg_idx, "buf") + gen(f"self._set_arg_buf({cl_arg_idx}, buf)") cl_arg_idx += 1 elif (work_around_arg_count_bug == "apple" @@ -195,7 +171,7 @@ def generate_specific_arg_handling_body(function_name, "buf = pack('{arg_char}{arg_char}', " "{arg_var}.real, {arg_var}.imag)" .format(arg_char=arg_char, arg_var=arg_var)) - generate_bytes_arg_setter(gen, cl_arg_idx, "buf") + gen(f"self._set_arg_buf({cl_arg_idx}, buf)") cl_arg_idx += 1 fp_arg_count += 2 @@ -211,7 +187,7 @@ def generate_specific_arg_handling_body(function_name, .format( arg_char=arg_char, arg_var=arg_var)) - generate_bytes_arg_setter(gen, cl_arg_idx, "buf") + gen(f"self._set_arg_buf({cl_arg_idx}, buf)") cl_arg_idx += 1 gen("") @@ -268,13 +244,6 @@ def wrap_in_error_handler(body, arg_names): # }}} -def add_local_imports(gen): - gen("import numpy as np") - gen("import pyopencl._cl as _cl") - gen("from pyopencl import _KERNEL_ARG_CLASSES") - gen("") - - def _generate_enqueue_and_set_args_module(function_name, num_passed_args, num_cl_args, scalar_arg_dtypes, @@ -292,12 +261,14 @@ def _generate_enqueue_and_set_args_module(function_name, warn_about_arg_count_bug=warn_about_arg_count_bug, work_around_arg_count_bug=work_around_arg_count_bug) - err_handler = wrap_in_error_handler(body, arg_names) + body = wrap_in_error_handler(body, arg_names) gen = PythonCodeGenerator() gen("from struct import pack") gen("from pyopencl import status_code") + gen("import numpy as np") + gen("import pyopencl._cl as _cl") gen("") # {{{ generate _enqueue @@ -314,8 +285,7 @@ def _generate_enqueue_and_set_args_module(function_name, "wait_for=None"]))) with Indentation(gen): - add_local_imports(gen) - gen.extend(err_handler) + gen.extend(body) gen(""" return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size, @@ -332,8 +302,7 @@ def _generate_enqueue_and_set_args_module(function_name, % (", ".join(["self"] + arg_names))) with Indentation(gen): - add_local_imports(gen) - gen.extend(err_handler) + gen.extend(body) # }}} @@ -341,7 +310,7 @@ def _generate_enqueue_and_set_args_module(function_name, invoker_cache = WriteOncePersistentDict( - "pyopencl-invoker-cache-v11", + "pyopencl-invoker-cache-v17", key_builder=_NumpyTypesKeyBuilder()) -- GitLab