diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 5083b095e2d881d50b8de45d7e17b2e8dda4c415..73d87154ea4d8a5b17ba09b6c2c1d404a0dd53cc 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -5,7 +5,8 @@ from __future__ import absolute_import from six.moves import range from six.moves import zip -__copyright__ = """Copyright 2011-2012 Andreas Kloeckner""" +__copyright__ = """Copyright 2011-2012 Andreas Kloeckner \ + Copyright 2017 Hao Gao""" __license__ = """ Permission is hereby granted, free of charge, to any person @@ -648,7 +649,14 @@ void ${kernel_name}(${kernel_list_arg_decl} USER_ARG_DECL index_type n) %if name not in count_sharing: index_type plb_${name}_index; if (plb_${name}_start_index) - plb_${name}_index = plb_${name}_start_index[i]; + %if eliminate_empty_output_lists: + plb_${name}_index = + plb_${name}_start_index[ + ${name}_compressed_indices[i] + ]; + %else: + plb_${name}_index = plb_${name}_start_index[i]; + %endif else plb_${name}_index = 0; %endif @@ -736,7 +744,8 @@ class ListOfListsBuilder: def __init__(self, context, list_names_and_dtypes, generate_template, arg_decls, count_sharing=None, devices=None, name_prefix="plb_build_list", options=[], preamble="", - debug=False, complex_kernel=False): + debug=False, complex_kernel=False, + eliminate_empty_output_lists=False): """ :arg context: A :class:`pyopencl.Context`. :arg list_names_and_dtypes: a list of `(name, dtype)` tuples @@ -810,6 +819,7 @@ class ListOfListsBuilder: self.debug = debug self.complex_kernel = complex_kernel + self.eliminate_empty_output_lists = eliminate_empty_output_lists # {{{ kernel generators @@ -824,6 +834,36 @@ class ListOfListsBuilder: output_statement="ary[i+1] = item;", devices=self.devices) + @memoize_method + def get_compress_kernel(self, index_dtype): + arguments = """ + __global ${index_t} *count, + __global ${index_t} *nonempty_indices, + __global ${index_t} *compressed_indices, + __global ${index_t} *num_non_empty_list + """ + from sys import version_info + if (version_info > (3, 0)): + arguments = Template(arguments) + else: + arguments = Template(arguments, disable_unicode=True) + + from pyopencl.scan import GenericScanKernel + return GenericScanKernel( + self.context, index_dtype, + arguments=arguments.render(index_t=dtype_to_ctype(index_dtype)), + input_expr="count[i] == 0 ? 0 : 1", + scan_expr="a+b", neutral="0", + output_statement=""" + compressed_indices[i + 1] = item; + if (prev_item != item) { + nonempty_indices[item - 1] = i; + count[item - 1] = count[i]; + } + if (i + 1 == N) *num_non_empty_list = item; + """, + devices=self.devices) + def do_not_vectorize(self): from pytools import any return (self.complex_kernel @@ -858,6 +898,7 @@ class ListOfListsBuilder: self.context.devices), debug=self.debug, do_not_vectorize=self.do_not_vectorize(), + eliminate_empty_output_lists=self.eliminate_empty_output_lists, kernel_list_arg_decl=_get_arg_decl(kernel_list_args), kernel_list_arg_values=_get_arg_list(user_list_args, prefix="&"), @@ -908,6 +949,10 @@ class ListOfListsBuilder: kernel_list_args.append( VectorArg(index_dtype, "plb_%s_start_index" % name)) + if self.eliminate_empty_output_lists: + kernel_list_args.append( + VectorArg(index_dtype, "%s_compressed_indices" % name)) + index_name = "plb_%s_index" % name user_list_args.append(OtherArg("%s *%s" % ( index_ctype, index_name), index_name)) @@ -924,6 +969,7 @@ class ListOfListsBuilder: self.context.devices), debug=self.debug, do_not_vectorize=self.do_not_vectorize(), + eliminate_empty_output_lists=self.eliminate_empty_output_lists, kernel_list_arg_decl=_get_arg_decl(kernel_list_args), kernel_list_arg_values=kernel_list_arg_values, @@ -983,7 +1029,18 @@ class ListOfListsBuilder: This implies that all lists are contiguous. - *event* is a :class:`pyopencl.Event` for dependency management. + If the *eliminate_empty_output_lists* constructor argument is set to + True, *lists* has two additional attributes ``num_nonempty_lists`` and + ``nonempty_indices`` + + * ``num_nonempty_lists`` for the number of nonempty lists. + * ``nonempty_indices`` for the index of nonempty list in input objects. + + In this case, `starts` has `num_nonempty_lists` + 1 entries. The *i*'s + entry is the start of the *i*'th nonempty list, which is generated by + the object with index *nonempty_indices[i]*. + + *event* is a :class:`pyopencl.Event` for dependency management. .. versionchanged:: 2016.2 @@ -1014,6 +1071,8 @@ class ListOfListsBuilder: count_kernel = self.get_count_kernel(index_dtype) write_kernel = self.get_write_kernel(index_dtype) scan_kernel = self.get_scan_kernel(index_dtype) + if self.eliminate_empty_output_lists: + compress_kernel = self.get_compress_kernel(index_dtype) # {{{ allocate memory for counts @@ -1031,7 +1090,12 @@ class ListOfListsBuilder: # The scan will turn the "counts" array into the "starts" array # in-place. - result[name] = BuiltList(starts=counts) + if self.eliminate_empty_output_lists: + result[name] = BuiltList(count=None, starts=counts, lists=None, + num_nonempty_lists=None, + nonempty_indices=None) + else: + result[name] = BuiltList(count=None, starts=counts, lists=None) count_list_args.append(counts.data) # }}} @@ -1050,6 +1114,29 @@ class ListOfListsBuilder: *(tuple(count_list_args) + args + (n_objects,)), **dict(wait_for=wait_for)) + if self.eliminate_empty_output_lists: + compress_events = {} + for name, dtype in self.list_names_and_dtypes: + if name in omit_lists: + continue + if name in self.count_sharing: + continue + + info_record = result[name] + info_record.nonempty_indices = cl.array.empty( + queue, (n_objects + 1,), index_dtype, allocator=allocator) + info_record.num_nonempty_lists = cl.array.empty( + queue, (1,), index_dtype, allocator=allocator) + info_record.compressed_indices = cl.array.empty( + queue, (n_objects + 1,), index_dtype, allocator=allocator) + info_record.compressed_indices[0] = 0 + compress_events[name] = compress_kernel( + info_record.starts, + info_record.nonempty_indices, + info_record.compressed_indices, + info_record.num_nonempty_lists, + wait_for=[count_event] + info_record.compressed_indices.events) + # {{{ run scans scan_events = [] @@ -1061,9 +1148,24 @@ class ListOfListsBuilder: continue info_record = result[name] + if self.eliminate_empty_output_lists: + compress_events[name].wait() + num_nonempty_lists = info_record.num_nonempty_lists.get()[0] + info_record.num_nonempty_lists = num_nonempty_lists + info_record.starts = info_record.starts[:num_nonempty_lists + 1] + info_record.nonempty_indices = \ + info_record.nonempty_indices[:num_nonempty_lists] + info_record.starts[-1] = 0 + starts_ary = info_record.starts - evt = scan_kernel(starts_ary, wait_for=[count_event], - size=n_objects) + if self.eliminate_empty_output_lists: + evt = scan_kernel( + starts_ary, + size=info_record.num_nonempty_lists, + wait_for=starts_ary.events) + else: + evt = scan_kernel(starts_ary, wait_for=[count_event], + size=n_objects) starts_ary.setitem(0, 0, queue=queue, wait_for=[evt]) scan_events.extend(starts_ary.events) @@ -1101,6 +1203,9 @@ class ListOfListsBuilder: if name not in self.count_sharing: write_list_args.append(info_record.starts.data) + if self.eliminate_empty_output_lists: + write_list_args.append(info_record.compressed_indices.data) + # }}} evt = write_kernel(queue, gsize, lsize, diff --git a/test/test_algorithm.py b/test/test_algorithm.py index 2e1537dfa2db1de92c00185fe094fdce37704e0e..ee32278bbf1c48142eb6dbf78963c03b3a9e313e 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -847,6 +847,36 @@ def test_list_builder(ctx_factory): assert inf.count == 3000 assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all() + builder = ListOfListsBuilder( + context, + [("mylist1", np.int32), ("mylist2", np.int32)], + """//CL// + void generate(LIST_ARG_DECL USER_ARG_DECL index_type i) + { + if (i % 5 == 0) + { + for (int j = 0; j < i / 5; ++j) + { + APPEND_mylist1(j); + APPEND_mylist2(j + 1); + } + } + } + """, + arg_decls=[], + eliminate_empty_output_lists=True) + + result, evt = builder(queue, 1000) + + mylist1 = result["mylist1"] + assert mylist1.count == 19900 + assert (mylist1.starts.get()[:5] == [0, 1, 3, 6, 10]).all() + assert (mylist1.nonempty_indices.get()[:5] == [5, 10, 15, 20, 25]).all() + assert (mylist1.lists.get()[:6] == [0, 0, 1, 0, 1, 2]).all() + mylist2 = result["mylist2"] + assert mylist2.count == 19900 + assert (mylist2.lists.get()[:6] == [1, 1, 2, 1, 2, 3]).all() + def test_key_value_sorter(ctx_factory): from pytest import importorskip