From 82ff5b92de50138d7d8f7201ca90815a1ee29d37 Mon Sep 17 00:00:00 2001 From: Hao Gao <gaohao95@gmail.com> Date: Tue, 14 Nov 2017 23:03:49 -0600 Subject: [PATCH] Add testing for sparse list --- pyopencl/algorithm.py | 4 +++- test/test_algorithm.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 5083b095..5f8287cf 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -736,7 +736,8 @@ class ListOfListsBuilder: def __init__(self, context, list_names_and_dtypes, generate_template, arg_decls, count_sharing=None, devices=None, name_prefix="plb_build_list", options=[], preamble="", - debug=False, complex_kernel=False): + debug=False, complex_kernel=False, + eliminate_empty_output_lists=False): """ :arg context: A :class:`pyopencl.Context`. :arg list_names_and_dtypes: a list of `(name, dtype)` tuples @@ -810,6 +811,7 @@ class ListOfListsBuilder: self.debug = debug self.complex_kernel = complex_kernel + self.eliminate_empty_output_lists = eliminate_empty_output_lists # {{{ kernel generators diff --git a/test/test_algorithm.py b/test/test_algorithm.py index 2e1537df..ab8fa6f7 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -847,6 +847,36 @@ def test_list_builder(ctx_factory): assert inf.count == 3000 assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all() + builder = ListOfListsBuilder( + context, + [("mylist1", np.int32), ("mylist2", np.int32)], + """//CL// + void generate(LIST_ARG_DECL USER_ARG_DECL index_type i) + { + if (i % 5 == 0) + { + for (int j = 0; j < 10; ++j) + { + APPEND_mylist1(j); + APPEND_mylist2(1); + } + } + } + """, + arg_decls=[], + eliminate_empty_output_lists=True) + + result, evt = builder(queue, 1000) + + mylist1 = result["mylist1"] + assert mylist1.count == 2000 + assert (mylist1.starts.get()[:5] == [0, 10, 20, 30, 40]).all() + assert (mylist1.indices.get()[:5] == [0, 5, 10, 15, 20]).all() + assert (mylist1.lists.get()[:5] == [0, 1, 2, 3, 4]).all() + mylist2 = result["mylist2"] + assert mylist2.count == 2000 + assert (mylist2.lists.get()[:5] == [1, 1, 1, 1, 1]).all() + def test_key_value_sorter(ctx_factory): from pytest import importorskip -- GitLab