diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index f7c0a92dec5dbb9e05fcbfdc7b826cf810374ac1..2b14a8efbc612bd534476b20d16627d250168135 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -649,7 +649,7 @@ void ${kernel_name}(${kernel_list_arg_decl} USER_ARG_DECL index_type n) %if name not in count_sharing: index_type plb_${name}_index; if (plb_${name}_start_index) - %if eliminate_empty_output_lists: + %if name in eliminate_empty_output_lists: plb_${name}_index = plb_${name}_start_index[ ${name}_compressed_indices[i] @@ -745,7 +745,7 @@ class ListOfListsBuilder: arg_decls, count_sharing=None, devices=None, name_prefix="plb_build_list", options=[], preamble="", debug=False, complex_kernel=False, - eliminate_empty_output_lists=False): + eliminate_empty_output_lists=[]): """ :arg context: A :class:`pyopencl.Context`. :arg list_names_and_dtypes: a list of `(name, dtype)` tuples @@ -760,6 +760,8 @@ class ListOfListsBuilder: :arg options: OpenCL compilation options for kernels using *generate_template*. :arg complex_kernel: If `True`, prevents vectorization on CPUs. + :arg eliminate_empty_output_lists: A Python list of list names + for which the empty output lists are eliminated. *generate_template* may use the following C macros/identifiers: @@ -793,6 +795,11 @@ class ListOfListsBuilder: and a second time, for a 'generation' stage where the lists are actually filled. A `generate` function that has side effects beyond calling `append` is therefore ill-formed. + + .. versionchanged:: 2018.1 + + Change *eliminate_empty_output_lists* argument type from `bool` to + `list`. """ if devices is None: @@ -819,7 +826,20 @@ class ListOfListsBuilder: self.debug = debug self.complex_kernel = complex_kernel + + if eliminate_empty_output_lists is True: + eliminate_empty_output_lists = \ + [name for name, _ in self.list_names_and_dtypes] + + if eliminate_empty_output_lists is False: + eliminate_empty_output_lists = [] + self.eliminate_empty_output_lists = eliminate_empty_output_lists + for list_name in self.eliminate_empty_output_lists: + if not any(list_name == name for name, _ in self.list_names_and_dtypes): + raise ValueError( + "invalid list name '%s' in eliminate_empty_output_lists" + % list_name) # {{{ kernel generators @@ -950,7 +970,7 @@ class ListOfListsBuilder: kernel_list_args.append( VectorArg(index_dtype, "plb_%s_start_index" % name)) - if self.eliminate_empty_output_lists: + if name in self.eliminate_empty_output_lists: kernel_list_args.append( VectorArg(index_dtype, "%s_compressed_indices" % name)) @@ -1030,9 +1050,9 @@ class ListOfListsBuilder: This implies that all lists are contiguous. - If the *eliminate_empty_output_lists* constructor argument is set to - True, *lists* has two additional attributes ``num_nonempty_lists`` and - ``nonempty_indices`` + If the list name is specified in *eliminate_empty_output_lists* + constructor argument, *lists* has two additional attributes + ``num_nonempty_lists`` and ``nonempty_indices`` * ``num_nonempty_lists`` for the number of nonempty lists. * ``nonempty_indices`` for the index of nonempty list in input objects. @@ -1091,7 +1111,7 @@ class ListOfListsBuilder: # The scan will turn the "counts" array into the "starts" array # in-place. - if self.eliminate_empty_output_lists: + if name in self.eliminate_empty_output_lists: result[name] = BuiltList(count=None, starts=counts, lists=None, num_nonempty_lists=None, nonempty_indices=None) @@ -1115,33 +1135,34 @@ class ListOfListsBuilder: *(tuple(count_list_args) + args + (n_objects,)), **dict(wait_for=wait_for)) - if self.eliminate_empty_output_lists: - compress_events = {} - for name, dtype in self.list_names_and_dtypes: - if name in omit_lists: - continue - if name in self.count_sharing: - continue - - compressed_counts = cl.array.empty( - queue, (n_objects + 1,), index_dtype, allocator=allocator) - info_record = result[name] - info_record.nonempty_indices = cl.array.empty( - queue, (n_objects + 1,), index_dtype, allocator=allocator) - info_record.num_nonempty_lists = cl.array.empty( - queue, (1,), index_dtype, allocator=allocator) - info_record.compressed_indices = cl.array.empty( - queue, (n_objects + 1,), index_dtype, allocator=allocator) - info_record.compressed_indices[0] = 0 - compress_events[name] = compress_kernel( - info_record.starts, - compressed_counts, - info_record.nonempty_indices, - info_record.compressed_indices, - info_record.num_nonempty_lists, - wait_for=[count_event] + info_record.compressed_indices.events) - - info_record.starts = compressed_counts + compress_events = {} + for name, dtype in self.list_names_and_dtypes: + if name in omit_lists: + continue + if name in self.count_sharing: + continue + if name not in self.eliminate_empty_output_lists: + continue + + compressed_counts = cl.array.empty( + queue, (n_objects + 1,), index_dtype, allocator=allocator) + info_record = result[name] + info_record.nonempty_indices = cl.array.empty( + queue, (n_objects + 1,), index_dtype, allocator=allocator) + info_record.num_nonempty_lists = cl.array.empty( + queue, (1,), index_dtype, allocator=allocator) + info_record.compressed_indices = cl.array.empty( + queue, (n_objects + 1,), index_dtype, allocator=allocator) + info_record.compressed_indices[0] = 0 + compress_events[name] = compress_kernel( + info_record.starts, + compressed_counts, + info_record.nonempty_indices, + info_record.compressed_indices, + info_record.num_nonempty_lists, + wait_for=[count_event] + info_record.compressed_indices.events) + + info_record.starts = compressed_counts # {{{ run scans @@ -1154,7 +1175,7 @@ class ListOfListsBuilder: continue info_record = result[name] - if self.eliminate_empty_output_lists: + if name in self.eliminate_empty_output_lists: compress_events[name].wait() num_nonempty_lists = info_record.num_nonempty_lists.get()[0] info_record.num_nonempty_lists = num_nonempty_lists @@ -1164,7 +1185,7 @@ class ListOfListsBuilder: info_record.starts[-1] = 0 starts_ary = info_record.starts - if self.eliminate_empty_output_lists: + if name in self.eliminate_empty_output_lists: evt = scan_kernel( starts_ary, size=info_record.num_nonempty_lists, @@ -1209,7 +1230,7 @@ class ListOfListsBuilder: if name not in self.count_sharing: write_list_args.append(info_record.starts.data) - if self.eliminate_empty_output_lists: + if name in self.eliminate_empty_output_lists: write_list_args.append(info_record.compressed_indices.data) # }}} diff --git a/test/test_algorithm.py b/test/test_algorithm.py index 0d956c1e3dc08acdae639fc032a76cc48c92f1c9..9bd15f862f5edcd08f900c03b418c55ac656da86 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -860,7 +860,7 @@ def test_list_builder_with_empty_elim(ctx_factory): builder = ListOfListsBuilder( context, - [("mylist1", np.int32), ("mylist2", np.int32)], + [("mylist1", np.int32), ("mylist2", np.int32), ("mylist3", np.int32)], """//CL// void generate(LIST_ARG_DECL USER_ARG_DECL index_type i) { @@ -870,12 +870,13 @@ def test_list_builder_with_empty_elim(ctx_factory): { APPEND_mylist1(j); APPEND_mylist2(j + 1); + APPEND_mylist3(j); } } } """, arg_decls=[], - eliminate_empty_output_lists=True) + eliminate_empty_output_lists=["mylist1", "mylist2"]) result, evt = builder(queue, 1000) @@ -887,6 +888,10 @@ def test_list_builder_with_empty_elim(ctx_factory): mylist2 = result["mylist2"] assert mylist2.count == 19900 assert (mylist2.lists.get()[:6] == [1, 1, 2, 1, 2, 3]).all() + mylist3 = result["mylist3"] + assert mylist3.count == 19900 + assert (mylist3.starts.get()[:10] == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]).all() + assert (mylist3.lists.get()[:6] == [0, 0, 1, 0, 1, 2]).all() def test_key_value_sorter(ctx_factory):