From 82ff5b92de50138d7d8f7201ca90815a1ee29d37 Mon Sep 17 00:00:00 2001
From: Hao Gao <gaohao95@gmail.com>
Date: Tue, 14 Nov 2017 23:03:49 -0600
Subject: [PATCH] Add testing for sparse list

---
 pyopencl/algorithm.py  |  4 +++-
 test/test_algorithm.py | 30 ++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 5083b095..5f8287cf 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -736,7 +736,8 @@ class ListOfListsBuilder:
     def __init__(self, context, list_names_and_dtypes, generate_template,
             arg_decls, count_sharing=None, devices=None,
             name_prefix="plb_build_list", options=[], preamble="",
-            debug=False, complex_kernel=False):
+            debug=False, complex_kernel=False,
+            eliminate_empty_output_lists=False):
         """
         :arg context: A :class:`pyopencl.Context`.
         :arg list_names_and_dtypes: a list of `(name, dtype)` tuples
@@ -810,6 +811,7 @@ class ListOfListsBuilder:
         self.debug = debug
 
         self.complex_kernel = complex_kernel
+        self.eliminate_empty_output_lists = eliminate_empty_output_lists
 
     # {{{ kernel generators
 
diff --git a/test/test_algorithm.py b/test/test_algorithm.py
index 2e1537df..ab8fa6f7 100644
--- a/test/test_algorithm.py
+++ b/test/test_algorithm.py
@@ -847,6 +847,36 @@ def test_list_builder(ctx_factory):
     assert inf.count == 3000
     assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all()
 
+    builder = ListOfListsBuilder(
+        context,
+        [("mylist1", np.int32), ("mylist2", np.int32)],
+        """//CL//
+            void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
+            {
+                if (i % 5 == 0)
+                {
+                    for (int j = 0; j < 10; ++j)
+                    {
+                        APPEND_mylist1(j);
+                        APPEND_mylist2(1);
+                    }
+                }
+            }
+        """,
+        arg_decls=[],
+        eliminate_empty_output_lists=True)
+
+    result, evt = builder(queue, 1000)
+
+    mylist1 = result["mylist1"]
+    assert mylist1.count == 2000
+    assert (mylist1.starts.get()[:5] == [0, 10, 20, 30, 40]).all()
+    assert (mylist1.indices.get()[:5] == [0, 5, 10, 15, 20]).all()
+    assert (mylist1.lists.get()[:5] == [0, 1, 2, 3, 4]).all()
+    mylist2 = result["mylist2"]
+    assert mylist2.count == 2000
+    assert (mylist2.lists.get()[:5] == [1, 1, 1, 1, 1]).all()
+
 
 def test_key_value_sorter(ctx_factory):
     from pytest import importorskip
-- 
GitLab