Skip to content
Snippets Groups Projects

Make list compression arg a list instead of a bool

Merged Hao Gao requested to merge list-compression-arg into master
All threads resolved!
2 files
+ 53
39
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 46
37
@@ -649,7 +649,7 @@ void ${kernel_name}(${kernel_list_arg_decl} USER_ARG_DECL index_type n)
%if name not in count_sharing:
index_type plb_${name}_index;
if (plb_${name}_start_index)
%if eliminate_empty_output_lists:
%if name in eliminate_empty_output_lists:
plb_${name}_index =
plb_${name}_start_index[
${name}_compressed_indices[i]
@@ -745,7 +745,7 @@ class ListOfListsBuilder:
arg_decls, count_sharing=None, devices=None,
name_prefix="plb_build_list", options=[], preamble="",
debug=False, complex_kernel=False,
eliminate_empty_output_lists=False):
eliminate_empty_output_lists=[]):
"""
:arg context: A :class:`pyopencl.Context`.
:arg list_names_and_dtypes: a list of `(name, dtype)` tuples
@@ -760,6 +760,8 @@ class ListOfListsBuilder:
:arg options: OpenCL compilation options for kernels using
*generate_template*.
:arg complex_kernel: If `True`, prevents vectorization on CPUs.
:arg eliminate_empty_output_lists: A Python list of list names
for which the empty output lists are eliminated.
*generate_template* may use the following C macros/identifiers:
@@ -819,7 +821,13 @@ class ListOfListsBuilder:
self.debug = debug
self.complex_kernel = complex_kernel
self.eliminate_empty_output_lists = eliminate_empty_output_lists
for list_name in self.eliminate_empty_output_lists:
if not any(list_name == name for name, _ in self.list_names_and_dtypes):
raise ValueError(
"invalid list name '%s' in eliminate_empty_output_lists"
% list_name)
# {{{ kernel generators
@@ -950,7 +958,7 @@ class ListOfListsBuilder:
kernel_list_args.append(
VectorArg(index_dtype, "plb_%s_start_index" % name))
if self.eliminate_empty_output_lists:
if name in self.eliminate_empty_output_lists:
kernel_list_args.append(
VectorArg(index_dtype, "%s_compressed_indices" % name))
@@ -1030,9 +1038,9 @@ class ListOfListsBuilder:
This implies that all lists are contiguous.
If the *eliminate_empty_output_lists* constructor argument is set to
True, *lists* has two additional attributes ``num_nonempty_lists`` and
``nonempty_indices``
If the list name is specified in *eliminate_empty_output_lists*
constructor argument, *lists* has two additional attributes
``num_nonempty_lists`` and ``nonempty_indices``
* ``num_nonempty_lists`` for the number of nonempty lists.
* ``nonempty_indices`` for the index of nonempty list in input objects.
@@ -1091,7 +1099,7 @@ class ListOfListsBuilder:
# The scan will turn the "counts" array into the "starts" array
# in-place.
if self.eliminate_empty_output_lists:
if name in self.eliminate_empty_output_lists:
result[name] = BuiltList(count=None, starts=counts, lists=None,
num_nonempty_lists=None,
nonempty_indices=None)
@@ -1115,33 +1123,34 @@ class ListOfListsBuilder:
*(tuple(count_list_args) + args + (n_objects,)),
**dict(wait_for=wait_for))
if self.eliminate_empty_output_lists:
compress_events = {}
for name, dtype in self.list_names_and_dtypes:
if name in omit_lists:
continue
if name in self.count_sharing:
continue
compressed_counts = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator)
info_record = result[name]
info_record.nonempty_indices = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator)
info_record.num_nonempty_lists = cl.array.empty(
queue, (1,), index_dtype, allocator=allocator)
info_record.compressed_indices = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator)
info_record.compressed_indices[0] = 0
compress_events[name] = compress_kernel(
info_record.starts,
compressed_counts,
info_record.nonempty_indices,
info_record.compressed_indices,
info_record.num_nonempty_lists,
wait_for=[count_event] + info_record.compressed_indices.events)
info_record.starts = compressed_counts
compress_events = {}
for name, dtype in self.list_names_and_dtypes:
if name in omit_lists:
continue
if name in self.count_sharing:
continue
if name not in self.eliminate_empty_output_lists:
continue
compressed_counts = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator)
info_record = result[name]
info_record.nonempty_indices = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator)
info_record.num_nonempty_lists = cl.array.empty(
queue, (1,), index_dtype, allocator=allocator)
info_record.compressed_indices = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator)
info_record.compressed_indices[0] = 0
compress_events[name] = compress_kernel(
info_record.starts,
compressed_counts,
info_record.nonempty_indices,
info_record.compressed_indices,
info_record.num_nonempty_lists,
wait_for=[count_event] + info_record.compressed_indices.events)
info_record.starts = compressed_counts
# {{{ run scans
@@ -1154,7 +1163,7 @@ class ListOfListsBuilder:
continue
info_record = result[name]
if self.eliminate_empty_output_lists:
if name in self.eliminate_empty_output_lists:
compress_events[name].wait()
num_nonempty_lists = info_record.num_nonempty_lists.get()[0]
info_record.num_nonempty_lists = num_nonempty_lists
@@ -1164,7 +1173,7 @@ class ListOfListsBuilder:
info_record.starts[-1] = 0
starts_ary = info_record.starts
if self.eliminate_empty_output_lists:
if name in self.eliminate_empty_output_lists:
evt = scan_kernel(
starts_ary,
size=info_record.num_nonempty_lists,
@@ -1209,7 +1218,7 @@ class ListOfListsBuilder:
if name not in self.count_sharing:
write_list_args.append(info_record.starts.data)
if self.eliminate_empty_output_lists:
if name in self.eliminate_empty_output_lists:
write_list_args.append(info_record.compressed_indices.data)
# }}}
Loading