From 56cd2bf152d44896e4255e18ad303b12b6817ef6 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Tue, 13 Nov 2018 21:30:17 -0600
Subject: [PATCH] ListOfListsBuilder: Add back support for MemoryObject
 arguments

---
 pyopencl/algorithm.py  | 26 ++++++++++++++++++++++----
 pyopencl/scan.py       | 11 +++++++++--
 pyopencl/tools.py      | 12 ------------
 test/test_algorithm.py | 25 +++++++++++++++++++++++++
 4 files changed, 56 insertions(+), 18 deletions(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 92730750..eb581f39 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -1111,8 +1111,26 @@ class ListOfListsBuilder:
         if self.eliminate_empty_output_lists:
             compress_kernel = self.get_compress_kernel(index_dtype)
 
-        from pyopencl.tools import expand_runtime_arg_list
-        args = expand_runtime_arg_list(self.arg_decls, args)
+        data_args = []
+        for i, (arg_descr, arg_val) in enumerate(zip(self.arg_decls, args)):
+            from pyopencl.tools import VectorArg
+            if isinstance(arg_descr, VectorArg):
+                from pyopencl import MemoryObject
+                if isinstance(arg_val, MemoryObject):
+                    data_args.append(arg_val)
+                    if arg_descr.with_offset:
+                        raise ValueError(
+                                "with_offset=True specified for argument '%d' "
+                                "but the argument is not an array." % i)
+
+                data_args.append(arg_val.base_data)
+                if arg_descr.with_offset:
+                    data_args.append(arg_val.offset)
+            else:
+                data_args.append(arg_val)
+
+        del args
+        data_args = tuple(data_args)
 
         # {{{ allocate memory for counts
 
@@ -1151,7 +1169,7 @@ class ListOfListsBuilder:
             gsize, lsize = splay(queue, n_objects)
 
         count_event = count_kernel(queue, gsize, lsize,
-                *(tuple(count_list_args) + args + (n_objects,)),
+                *(tuple(count_list_args) + data_args + (n_objects,)),
                 **dict(wait_for=wait_for))
 
         compress_events = {}
@@ -1257,7 +1275,7 @@ class ListOfListsBuilder:
         # }}}
 
         evt = write_kernel(queue, gsize, lsize,
-                *(tuple(write_list_args) + args + (n_objects,)),
+                *(tuple(write_list_args) + data_args + (n_objects,)),
                 **dict(wait_for=scan_events))
 
         return result, evt
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 6218d571..cef68c67 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -1480,8 +1480,15 @@ class GenericScanKernel(_GenericScanKernelBase):
             # We're done here. (But pretend to return an event.)
             return cl.enqueue_marker(queue, wait_for=wait_for)
 
-        from pyopencl.tools import expand_runtime_arg_list
-        data_args = list(expand_runtime_arg_list(self.parsed_args, args))
+        data_args = []
+        for arg_descr, arg_val in zip(self.parsed_args, args):
+            from pyopencl.tools import VectorArg
+            if isinstance(arg_descr, VectorArg):
+                data_args.append(arg_val.base_data)
+                if arg_descr.with_offset:
+                    data_args.append(arg_val.offset)
+            else:
+                data_args.append(arg_val)
 
         # }}}
 
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index b1dcf9f2..369fb272 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -400,18 +400,6 @@ def get_arg_offset_adjuster_code(arg_types):
 
     return "\n".join(result)
 
-
-def expand_runtime_arg_list(args, user_args):
-    data_args = []
-    for arg_descr, arg_val in zip(args, user_args):
-        if isinstance(arg_descr, VectorArg):
-            data_args.append(arg_val.base_data)
-            if arg_descr.with_offset:
-                data_args.append(arg_val.offset)
-        else:
-            data_args.append(arg_val)
-    return tuple(data_args)
-
 # }}}
 
 
diff --git a/test/test_algorithm.py b/test/test_algorithm.py
index 5c09b565..0360d6a3 100644
--- a/test/test_algorithm.py
+++ b/test/test_algorithm.py
@@ -880,6 +880,31 @@ def test_list_builder(ctx_factory):
     assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all()
 
 
+def test_list_builder_with_memoryobject(ctx_factory):
+    from pytest import importorskip
+    importorskip("mako")
+
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    from pyopencl.algorithm import ListOfListsBuilder
+    from pyopencl.tools import VectorArg
+    builder = ListOfListsBuilder(context, [("mylist", np.int32)], """//CL//
+            void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
+            {
+                APPEND_mylist(input_list[i]);
+            }
+            """, arg_decls=[VectorArg(float, "input_list")])
+
+    n = 10000
+    input_list = cl.array.zeros(queue, (n,), float)
+    result, evt = builder(queue, n, input_list.data)
+
+    inf = result["mylist"]
+    assert inf.count == n
+    assert (inf.lists.get() == 0).all()
+
+
 def test_list_builder_with_offset(ctx_factory):
     from pytest import importorskip
     importorskip("mako")
-- 
GitLab