diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 8d0c309b08b8df4cda7e13c097441ef272449a02..df7f658688ed4de2ab8d6b9131be733820a1e272 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -292,6 +292,12 @@ class PyOpenCLTarget(OpenCLTarget): super().__init__( atomics_flavor=atomics_flavor) + import pyopencl.version + if pyopencl.version.VERSION < (2021, 1): + raise RuntimeError("The version of loopy you have installed " + "generates invoker code that requires PyOpenCL 2021.1 " + "or newer.") + self.device = device self.pyopencl_module_name = pyopencl_module_name @@ -484,12 +490,25 @@ def generate_value_arg_setup(kernel, devices, implemented_data_info): fp_arg_count = 0 - from genpy import ( - Comment, Line, If, Raise, Assign, Statement as S, Suite) + from genpy import If, Raise, Statement as S, Suite result = [] gen = result.append + buf_indices_and_args = [] + buf_pack_indices_and_args = [] + + from pyopencl.invoker import BUF_PACK_TYPECHARS + + def add_buf_arg(arg_idx, typechar, expr_str): + if typechar in BUF_PACK_TYPECHARS: + buf_pack_indices_and_args.append(arg_idx) + buf_pack_indices_and_args.append(repr(typechar.encode())) + buf_pack_indices_and_args.append(expr_str) + else: + buf_indices_and_args.append(arg_idx) + buf_indices_and_args.append(f"pack('{typechar}', {expr_str})") + for arg_idx, idi in enumerate(implemented_data_info): arg_idx_to_cl_arg_idx[arg_idx] = cl_arg_idx @@ -501,16 +520,15 @@ def generate_value_arg_setup(kernel, devices, implemented_data_info): continue - gen(Comment("{{{ process %s" % idi.name)) - gen(Line()) - if not options.skip_arg_checks: gen(If("%s is None" % idi.name, Raise('RuntimeError("input argument \'{name}\' ' 'must be supplied")'.format(name=idi.name)))) if idi.dtype.is_composite(): - gen(S("_lpy_knl.set_arg(%d, %s)" % (cl_arg_idx, idi.name))) + buf_indices_and_args.append(cl_arg_idx) + buf_indices_and_args.append(f"{idi.name}") + cl_arg_idx += 1 elif idi.dtype.is_complex(): @@ -535,32 +553,16 @@ def generate_value_arg_setup(kernel, devices, implemented_data_info): if (work_around_arg_count_bug and dtype.numpy_dtype == np.complex128 and fp_arg_count + 2 <= 8): - gen(Assign( - "_lpy_buf", - "_lpy_pack('{arg_char}', {arg_var}.real)" - .format(arg_char=arg_char, arg_var=idi.name))) - gen(S( - "_lpy_knl.set_arg({cl_arg_idx}, _lpy_buf)" - .format(cl_arg_idx=cl_arg_idx))) + add_buf_arg(cl_arg_idx, arg_char, f"{idi.name}.real") cl_arg_idx += 1 - gen(Assign( - "_lpy_buf", - "_lpy_pack('{arg_char}', {arg_var}.imag)" - .format(arg_char=arg_char, arg_var=idi.name))) - gen(S( - "_lpy_knl.set_arg({cl_arg_idx}, _lpy_buf)" - .format(cl_arg_idx=cl_arg_idx))) + add_buf_arg(cl_arg_idx, arg_char, f"{idi.name}.imag") cl_arg_idx += 1 else: - gen(Assign( - "_lpy_buf", - "_lpy_pack('{arg_char}{arg_char}', " - "{arg_var}.real, {arg_var}.imag)" - .format(arg_char=arg_char, arg_var=idi.name))) - gen(S( - "_lpy_knl.set_arg({cl_arg_idx}, _lpy_buf)" - .format(cl_arg_idx=cl_arg_idx))) + buf_indices_and_args.append(cl_arg_idx) + buf_indices_and_args.append( + f"_lpy_pack('{arg_char}{arg_char}', " + f"{idi.name}.real, {idi.name}.imag)") cl_arg_idx += 1 fp_arg_count += 2 @@ -569,20 +571,22 @@ def generate_value_arg_setup(kernel, devices, implemented_data_info): if idi.dtype.dtype.kind == "f": fp_arg_count += 1 - gen(S( - "_lpy_knl._set_arg_buf(%d, _lpy_pack('%s', %s))" - % (cl_arg_idx, idi.dtype.dtype.char, idi.name))) - + add_buf_arg(cl_arg_idx, idi.dtype.dtype.char, idi.name) cl_arg_idx += 1 else: raise LoopyError("do not know how to pass argument of type '%s'" % idi.dtype) - gen(Line()) - - gen(Comment("}}}")) - gen(Line()) + for arg_kind, args_and_indices, entry_length in [ + ("_buf", buf_indices_and_args, 2), + ("_buf_pack", buf_pack_indices_and_args, 3), + ]: + assert len(args_and_indices) % entry_length == 0 + if args_and_indices: + gen(S(f"_lpy_knl._set_arg{arg_kind}_multi(" + f"({', '.join(str(i) for i in args_and_indices)},), " + ")")) return Suite(result), arg_idx_to_cl_arg_idx, cl_arg_idx @@ -596,13 +600,18 @@ def generate_array_arg_setup(kernel, implemented_data_info, arg_idx_to_cl_arg_id result = [] gen = result.append + cl_indices_and_args = [] for arg_idx, arg in enumerate(implemented_data_info): - if not issubclass(arg.arg_class, ArrayBase): - continue + if issubclass(arg.arg_class, ArrayBase): + cl_indices_and_args.append(arg_idx_to_cl_arg_idx[arg_idx]) + cl_indices_and_args.append(arg.name) - cl_arg_idx = arg_idx_to_cl_arg_idx[arg_idx] + if cl_indices_and_args: + assert len(cl_indices_and_args) % 2 == 0 - gen(S("_lpy_knl.set_arg(%d, %s)" % (cl_arg_idx, arg.name))) + gen(S(f"_lpy_knl._set_arg_multi(" + f"({', '.join(str(i) for i in cl_indices_and_args)},)" + ")")) return Suite(result)