From 56cd2bf152d44896e4255e18ad303b12b6817ef6 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Tue, 13 Nov 2018 21:30:17 -0600 Subject: [PATCH] ListOfListsBuilder: Add back support for MemoryObject arguments --- pyopencl/algorithm.py | 26 ++++++++++++++++++++++---- pyopencl/scan.py | 11 +++++++++-- pyopencl/tools.py | 12 ------------ test/test_algorithm.py | 25 +++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 92730750..eb581f39 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -1111,8 +1111,26 @@ class ListOfListsBuilder: if self.eliminate_empty_output_lists: compress_kernel = self.get_compress_kernel(index_dtype) - from pyopencl.tools import expand_runtime_arg_list - args = expand_runtime_arg_list(self.arg_decls, args) + data_args = [] + for i, (arg_descr, arg_val) in enumerate(zip(self.arg_decls, args)): + from pyopencl.tools import VectorArg + if isinstance(arg_descr, VectorArg): + from pyopencl import MemoryObject + if isinstance(arg_val, MemoryObject): + data_args.append(arg_val) + if arg_descr.with_offset: + raise ValueError( + "with_offset=True specified for argument '%d' " + "but the argument is not an array." % i) + + data_args.append(arg_val.base_data) + if arg_descr.with_offset: + data_args.append(arg_val.offset) + else: + data_args.append(arg_val) + + del args + data_args = tuple(data_args) # {{{ allocate memory for counts @@ -1151,7 +1169,7 @@ class ListOfListsBuilder: gsize, lsize = splay(queue, n_objects) count_event = count_kernel(queue, gsize, lsize, - *(tuple(count_list_args) + args + (n_objects,)), + *(tuple(count_list_args) + data_args + (n_objects,)), **dict(wait_for=wait_for)) compress_events = {} @@ -1257,7 +1275,7 @@ class ListOfListsBuilder: # }}} evt = write_kernel(queue, gsize, lsize, - *(tuple(write_list_args) + args + (n_objects,)), + *(tuple(write_list_args) + data_args + (n_objects,)), **dict(wait_for=scan_events)) return result, evt diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 6218d571..cef68c67 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -1480,8 +1480,15 @@ class GenericScanKernel(_GenericScanKernelBase): # We're done here. (But pretend to return an event.) return cl.enqueue_marker(queue, wait_for=wait_for) - from pyopencl.tools import expand_runtime_arg_list - data_args = list(expand_runtime_arg_list(self.parsed_args, args)) + data_args = [] + for arg_descr, arg_val in zip(self.parsed_args, args): + from pyopencl.tools import VectorArg + if isinstance(arg_descr, VectorArg): + data_args.append(arg_val.base_data) + if arg_descr.with_offset: + data_args.append(arg_val.offset) + else: + data_args.append(arg_val) # }}} diff --git a/pyopencl/tools.py b/pyopencl/tools.py index b1dcf9f2..369fb272 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -400,18 +400,6 @@ def get_arg_offset_adjuster_code(arg_types): return "\n".join(result) - -def expand_runtime_arg_list(args, user_args): - data_args = [] - for arg_descr, arg_val in zip(args, user_args): - if isinstance(arg_descr, VectorArg): - data_args.append(arg_val.base_data) - if arg_descr.with_offset: - data_args.append(arg_val.offset) - else: - data_args.append(arg_val) - return tuple(data_args) - # }}} diff --git a/test/test_algorithm.py b/test/test_algorithm.py index 5c09b565..0360d6a3 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -880,6 +880,31 @@ def test_list_builder(ctx_factory): assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all() +def test_list_builder_with_memoryobject(ctx_factory): + from pytest import importorskip + importorskip("mako") + + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.algorithm import ListOfListsBuilder + from pyopencl.tools import VectorArg + builder = ListOfListsBuilder(context, [("mylist", np.int32)], """//CL// + void generate(LIST_ARG_DECL USER_ARG_DECL index_type i) + { + APPEND_mylist(input_list[i]); + } + """, arg_decls=[VectorArg(float, "input_list")]) + + n = 10000 + input_list = cl.array.zeros(queue, (n,), float) + result, evt = builder(queue, n, input_list.data) + + inf = result["mylist"] + assert inf.count == n + assert (inf.lists.get() == 0).all() + + def test_list_builder_with_offset(ctx_factory): from pytest import importorskip importorskip("mako") -- GitLab