Skip to content
Snippets Groups Projects
Commit 75e460b7 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Streamline invoker code (and its generation)

parent 6f759230
No related branches found
No related tags found
1 merge request!134Speed up enqueue
...@@ -28,6 +28,7 @@ import numpy as np ...@@ -28,6 +28,7 @@ import numpy as np
from warnings import warn from warnings import warn
import pyopencl._cl as _cl import pyopencl._cl as _cl
from pytools.persistent_dict import WriteOncePersistentDict from pytools.persistent_dict import WriteOncePersistentDict
from pytools.py_codegen import Indentation
from pyopencl.tools import _NumpyTypesKeyBuilder from pyopencl.tools import _NumpyTypesKeyBuilder
_PYPY = "__pypy__" in sys.builtin_module_names _PYPY = "__pypy__" in sys.builtin_module_names
...@@ -53,8 +54,6 @@ del _size_t_char ...@@ -53,8 +54,6 @@ del _size_t_char
# {{{ individual arg handling # {{{ individual arg handling
def generate_buffer_arg_setter(gen, arg_idx, buf_var): def generate_buffer_arg_setter(gen, arg_idx, buf_var):
from pytools.py_codegen import Indentation
if _PYPY: if _PYPY:
# https://github.com/numpy/numpy/issues/5381 # https://github.com/numpy/numpy/issues/5381
gen(f"if isinstance({buf_var}, np.generic):") gen(f"if isinstance({buf_var}, np.generic):")
...@@ -69,29 +68,6 @@ def generate_buffer_arg_setter(gen, arg_idx, buf_var): ...@@ -69,29 +68,6 @@ def generate_buffer_arg_setter(gen, arg_idx, buf_var):
""" """
.format(arg_idx=arg_idx, buf_var=buf_var)) .format(arg_idx=arg_idx, buf_var=buf_var))
def generate_bytes_arg_setter(gen, arg_idx, buf_var):
gen("""
self._set_arg_buf({arg_idx}, {buf_var})
"""
.format(arg_idx=arg_idx, buf_var=buf_var))
def generate_generic_arg_handler(gen, arg_idx, arg_var):
from pytools.py_codegen import Indentation
gen("""
if isinstance({arg_var}, _KERNEL_ARG_CLASSES):
self.set_arg({arg_idx}, {arg_var})
elif {arg_var} is None:
self._set_arg_null({arg_idx})
"""
.format(arg_idx=arg_idx, arg_var=arg_var))
gen("else:")
with Indentation(gen):
generate_buffer_arg_setter(gen, arg_idx, arg_var)
# }}} # }}}
...@@ -108,7 +84,7 @@ def generate_generic_arg_handling_body(num_args): ...@@ -108,7 +84,7 @@ def generate_generic_arg_handling_body(num_args):
gen(f"# process argument {i}") gen(f"# process argument {i}")
gen("") gen("")
gen(f"current_arg = {i}") gen(f"current_arg = {i}")
generate_generic_arg_handler(gen, i, "arg%d" % i) gen(f"self.set_arg({i}, arg{i})")
gen("") gen("")
return gen return gen
...@@ -141,7 +117,7 @@ def generate_specific_arg_handling_body(function_name, ...@@ -141,7 +117,7 @@ def generate_specific_arg_handling_body(function_name,
arg_var = "arg%d" % arg_idx arg_var = "arg%d" % arg_idx
if arg_dtype is None: if arg_dtype is None:
generate_generic_arg_handler(gen, cl_arg_idx, arg_var) gen(f"self.set_arg({cl_arg_idx}, {arg_var})")
cl_arg_idx += 1 cl_arg_idx += 1
gen("") gen("")
continue continue
...@@ -149,7 +125,7 @@ def generate_specific_arg_handling_body(function_name, ...@@ -149,7 +125,7 @@ def generate_specific_arg_handling_body(function_name,
arg_dtype = np.dtype(arg_dtype) arg_dtype = np.dtype(arg_dtype)
if arg_dtype.char == "V": if arg_dtype.char == "V":
generate_generic_arg_handler(gen, cl_arg_idx, arg_var) gen(f"self.set_arg({cl_arg_idx}, {arg_var})")
cl_arg_idx += 1 cl_arg_idx += 1
elif arg_dtype.kind == "c": elif arg_dtype.kind == "c":
...@@ -173,13 +149,13 @@ def generate_specific_arg_handling_body(function_name, ...@@ -173,13 +149,13 @@ def generate_specific_arg_handling_body(function_name,
gen( gen(
"buf = pack('{arg_char}', {arg_var}.real)" "buf = pack('{arg_char}', {arg_var}.real)"
.format(arg_char=arg_char, arg_var=arg_var)) .format(arg_char=arg_char, arg_var=arg_var))
generate_bytes_arg_setter(gen, cl_arg_idx, "buf") gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
cl_arg_idx += 1 cl_arg_idx += 1
gen("current_arg = current_arg + 1000") gen("current_arg = current_arg + 1000")
gen( gen(
"buf = pack('{arg_char}', {arg_var}.imag)" "buf = pack('{arg_char}', {arg_var}.imag)"
.format(arg_char=arg_char, arg_var=arg_var)) .format(arg_char=arg_char, arg_var=arg_var))
generate_bytes_arg_setter(gen, cl_arg_idx, "buf") gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
cl_arg_idx += 1 cl_arg_idx += 1
elif (work_around_arg_count_bug == "apple" elif (work_around_arg_count_bug == "apple"
...@@ -195,7 +171,7 @@ def generate_specific_arg_handling_body(function_name, ...@@ -195,7 +171,7 @@ def generate_specific_arg_handling_body(function_name,
"buf = pack('{arg_char}{arg_char}', " "buf = pack('{arg_char}{arg_char}', "
"{arg_var}.real, {arg_var}.imag)" "{arg_var}.real, {arg_var}.imag)"
.format(arg_char=arg_char, arg_var=arg_var)) .format(arg_char=arg_char, arg_var=arg_var))
generate_bytes_arg_setter(gen, cl_arg_idx, "buf") gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
cl_arg_idx += 1 cl_arg_idx += 1
fp_arg_count += 2 fp_arg_count += 2
...@@ -211,7 +187,7 @@ def generate_specific_arg_handling_body(function_name, ...@@ -211,7 +187,7 @@ def generate_specific_arg_handling_body(function_name,
.format( .format(
arg_char=arg_char, arg_char=arg_char,
arg_var=arg_var)) arg_var=arg_var))
generate_bytes_arg_setter(gen, cl_arg_idx, "buf") gen(f"self._set_arg_buf({cl_arg_idx}, buf)")
cl_arg_idx += 1 cl_arg_idx += 1
gen("") gen("")
...@@ -268,13 +244,6 @@ def wrap_in_error_handler(body, arg_names): ...@@ -268,13 +244,6 @@ def wrap_in_error_handler(body, arg_names):
# }}} # }}}
def add_local_imports(gen):
gen("import numpy as np")
gen("import pyopencl._cl as _cl")
gen("from pyopencl import _KERNEL_ARG_CLASSES")
gen("")
def _generate_enqueue_and_set_args_module(function_name, def _generate_enqueue_and_set_args_module(function_name,
num_passed_args, num_cl_args, num_passed_args, num_cl_args,
scalar_arg_dtypes, scalar_arg_dtypes,
...@@ -292,12 +261,14 @@ def _generate_enqueue_and_set_args_module(function_name, ...@@ -292,12 +261,14 @@ def _generate_enqueue_and_set_args_module(function_name,
warn_about_arg_count_bug=warn_about_arg_count_bug, warn_about_arg_count_bug=warn_about_arg_count_bug,
work_around_arg_count_bug=work_around_arg_count_bug) work_around_arg_count_bug=work_around_arg_count_bug)
err_handler = wrap_in_error_handler(body, arg_names) body = wrap_in_error_handler(body, arg_names)
gen = PythonCodeGenerator() gen = PythonCodeGenerator()
gen("from struct import pack") gen("from struct import pack")
gen("from pyopencl import status_code") gen("from pyopencl import status_code")
gen("import numpy as np")
gen("import pyopencl._cl as _cl")
gen("") gen("")
# {{{ generate _enqueue # {{{ generate _enqueue
...@@ -314,8 +285,7 @@ def _generate_enqueue_and_set_args_module(function_name, ...@@ -314,8 +285,7 @@ def _generate_enqueue_and_set_args_module(function_name,
"wait_for=None"]))) "wait_for=None"])))
with Indentation(gen): with Indentation(gen):
add_local_imports(gen) gen.extend(body)
gen.extend(err_handler)
gen(""" gen("""
return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size, return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
...@@ -332,8 +302,7 @@ def _generate_enqueue_and_set_args_module(function_name, ...@@ -332,8 +302,7 @@ def _generate_enqueue_and_set_args_module(function_name,
% (", ".join(["self"] + arg_names))) % (", ".join(["self"] + arg_names)))
with Indentation(gen): with Indentation(gen):
add_local_imports(gen) gen.extend(body)
gen.extend(err_handler)
# }}} # }}}
...@@ -341,7 +310,7 @@ def _generate_enqueue_and_set_args_module(function_name, ...@@ -341,7 +310,7 @@ def _generate_enqueue_and_set_args_module(function_name,
invoker_cache = WriteOncePersistentDict( invoker_cache = WriteOncePersistentDict(
"pyopencl-invoker-cache-v11", "pyopencl-invoker-cache-v17",
key_builder=_NumpyTypesKeyBuilder()) key_builder=_NumpyTypesKeyBuilder())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment