Skip to content
Snippets Groups Projects
Commit cf9c65cc authored by Hao Gao's avatar Hao Gao
Browse files

Fix list builder bug. See #5, #7

parent 59f15020
No related branches found
No related tags found
Loading
Pipeline #
...@@ -838,6 +838,7 @@ class ListOfListsBuilder: ...@@ -838,6 +838,7 @@ class ListOfListsBuilder:
def get_compress_kernel(self, index_dtype): def get_compress_kernel(self, index_dtype):
arguments = """ arguments = """
__global ${index_t} *count, __global ${index_t} *count,
__global ${index_t} *compressed_counts,
__global ${index_t} *nonempty_indices, __global ${index_t} *nonempty_indices,
__global ${index_t} *compressed_indices, __global ${index_t} *compressed_indices,
__global ${index_t} *num_non_empty_list __global ${index_t} *num_non_empty_list
...@@ -858,7 +859,7 @@ class ListOfListsBuilder: ...@@ -858,7 +859,7 @@ class ListOfListsBuilder:
compressed_indices[i + 1] = item; compressed_indices[i + 1] = item;
if (prev_item != item) { if (prev_item != item) {
nonempty_indices[item - 1] = i; nonempty_indices[item - 1] = i;
count[item - 1] = count[i]; compressed_counts[item - 1] = count[i];
} }
if (i + 1 == N) *num_non_empty_list = item; if (i + 1 == N) *num_non_empty_list = item;
""", """,
...@@ -1122,6 +1123,8 @@ class ListOfListsBuilder: ...@@ -1122,6 +1123,8 @@ class ListOfListsBuilder:
if name in self.count_sharing: if name in self.count_sharing:
continue continue
compressed_counts = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator)
info_record = result[name] info_record = result[name]
info_record.nonempty_indices = cl.array.empty( info_record.nonempty_indices = cl.array.empty(
queue, (n_objects + 1,), index_dtype, allocator=allocator) queue, (n_objects + 1,), index_dtype, allocator=allocator)
...@@ -1132,11 +1135,14 @@ class ListOfListsBuilder: ...@@ -1132,11 +1135,14 @@ class ListOfListsBuilder:
info_record.compressed_indices[0] = 0 info_record.compressed_indices[0] = 0
compress_events[name] = compress_kernel( compress_events[name] = compress_kernel(
info_record.starts, info_record.starts,
compressed_counts,
info_record.nonempty_indices, info_record.nonempty_indices,
info_record.compressed_indices, info_record.compressed_indices,
info_record.num_nonempty_lists, info_record.num_nonempty_lists,
wait_for=[count_event] + info_record.compressed_indices.events) wait_for=[count_event] + info_record.compressed_indices.events)
info_record.starts = compressed_counts
# {{{ run scans # {{{ run scans
scan_events = [] scan_events = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment