diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 302c0a7901932768a1cc2b8242a2138b69856ea7..dc3be4bed8977b9f33b759b463925204a59ff714 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -650,7 +650,7 @@ void ${kernel_name}(${kernel_list_arg_decl} USER_ARG_DECL index_type n) if (plb_${name}_start_index) %if eliminate_empty_output_lists: plb_${name}_index = - plb_${name}_start_index[plb_${name}_mask_scan[i]]; + plb_${name}_start_index[${name}_compressed_indices[i]]; %else: plb_${name}_index = plb_${name}_start_index[i]; %endif @@ -835,8 +835,8 @@ class ListOfListsBuilder: def get_compress_kernel(self, index_dtype): arguments = """ __global ${index_t} *count, - __global ${index_t} *indices, - __global ${index_t} *mask_scan, + __global ${index_t} *nonempty_indices, + __global ${index_t} *compressed_indices, __global ${index_t} *num_non_empty_list """ from sys import version_info @@ -852,9 +852,9 @@ class ListOfListsBuilder: input_expr="count[i] == 0 ? 0 : 1", scan_expr="a+b", neutral="0", output_statement=""" - mask_scan[i + 1] = item; + compressed_indices[i + 1] = item; if (prev_item != item) { - indices[item - 1] = i; + nonempty_indices[item - 1] = i; count[item - 1] = count[i]; } if (i + 1 == N) *num_non_empty_list = item; @@ -948,7 +948,7 @@ class ListOfListsBuilder: if self.eliminate_empty_output_lists: kernel_list_args.append( - VectorArg(index_dtype, "plb_%s_mask_scan" % name)) + VectorArg(index_dtype, "%s_compressed_indices" % name)) index_name = "plb_%s_index" % name user_list_args.append(OtherArg("%s *%s" % ( @@ -1026,15 +1026,16 @@ class ListOfListsBuilder: This implies that all lists are contiguous. - If eliminate_empty_output_lists is set to True, *lists* has two - additional attributes ``num_nonempty_lists`` and ``indices`` + If the *eliminate_empty_output_lists* constructor argument is set to + True, *lists* has two additional attributes ``num_nonempty_lists`` and + ``nonempty_indices`` * ``num_nonempty_lists`` for the number of nonempty lists. - * ``indices`` for the index of nonempty list in input objects. + * ``nonempty_indices`` for the index of nonempty list in input objects. In this case, `starts` has `num_nonempty_lists` + 1 entries. The *i*'s entry is the start of the *i*'th nonempty list, which is generated by - the object with index *indices[i]*. + the object with index *nonempty_indices[i]*. *event* is a :class:`pyopencl.Event` for dependency management. @@ -1114,19 +1115,19 @@ class ListOfListsBuilder: continue info_record = result[name] - info_record.indices = cl.array.empty( + info_record.nonempty_indices = cl.array.empty( queue, (n_objects + 1,), index_dtype, allocator=allocator) info_record.num_nonempty_lists = cl.array.empty( queue, (1,), index_dtype, allocator=allocator) - info_record.mask_scan = cl.array.empty( + info_record.compressed_indices = cl.array.empty( queue, (n_objects + 1,), index_dtype, allocator=allocator) - info_record.mask_scan[0] = 0 + info_record.compressed_indices[0] = 0 compress_events[name] = compress_kernel( info_record.starts, - info_record.indices, - info_record.mask_scan, + info_record.nonempty_indices, + info_record.compressed_indices, info_record.num_nonempty_lists, - wait_for=[count_event] + info_record.mask_scan.events) + wait_for=[count_event] + info_record.compressed_indices.events) # {{{ run scans @@ -1144,7 +1145,8 @@ class ListOfListsBuilder: num_nonempty_lists = info_record.num_nonempty_lists.get()[0] info_record.num_nonempty_lists = num_nonempty_lists info_record.starts = info_record.starts[:num_nonempty_lists + 1] - info_record.indices = info_record.indices[:num_nonempty_lists] + info_record.nonempty_indices = \ + info_record.nonempty_indices[:num_nonempty_lists] info_record.starts[-1] = 0 starts_ary = info_record.starts @@ -1194,7 +1196,7 @@ class ListOfListsBuilder: write_list_args.append(info_record.starts.data) if self.eliminate_empty_output_lists: - write_list_args.append(info_record.mask_scan.data) + write_list_args.append(info_record.compressed_indices.data) # }}} diff --git a/test/test_algorithm.py b/test/test_algorithm.py index 3846d1860b4bd5007397835f5d3fd735b846962b..ee32278bbf1c48142eb6dbf78963c03b3a9e313e 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -871,7 +871,7 @@ def test_list_builder(ctx_factory): mylist1 = result["mylist1"] assert mylist1.count == 19900 assert (mylist1.starts.get()[:5] == [0, 1, 3, 6, 10]).all() - assert (mylist1.indices.get()[:5] == [5, 10, 15, 20, 25]).all() + assert (mylist1.nonempty_indices.get()[:5] == [5, 10, 15, 20, 25]).all() assert (mylist1.lists.get()[:6] == [0, 0, 1, 0, 1, 2]).all() mylist2 = result["mylist2"] assert mylist2.count == 19900