Skip to content
Snippets Groups Projects
Commit 4a900040 authored by Andreas Klöckner
Browse files

Speed up write-out of scan data.

parent 7f06cf21
Branches
Tags
No related merge requests found
...@@ -55,12 +55,12 @@ SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL// ...@@ -55,12 +55,12 @@ SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
${preamble} ${preamble}
typedef ${scan_ctype} scan_type; typedef ${dtype_to_ctype(scan_dtype)} scan_type;
typedef ${index_ctype} index_type; typedef ${dtype_to_ctype(index_dtype)} index_type;
// NO_SEG_BOUNDARY is the largest representable integer in index_type. // NO_SEG_BOUNDARY is the largest representable integer in index_type.
// This assumption is used in code below. // This assumption is used in code below.
#define NO_SEG_BOUNDARY ${index_type_max} #define NO_SEG_BOUNDARY ${str(np.iinfo(index_dtype).max)}
""" """
# }}} # }}}
...@@ -484,18 +484,44 @@ void ${name_prefix}_scan_intervals( ...@@ -484,18 +484,44 @@ void ${name_prefix}_scan_intervals(
// {{{ write data // {{{ write data
for (index_type k = 0; k < K; k++) // work hard with index math to achieve contiguous 32-bit stores
{ {
const index_type offset = k*WG_SIZE + LID_0; __global int *dest = (__global int *) (partial_scan_buffer + unit_base);
%if is_tail: <%
if (unit_base + offset < interval_end)
%endif assert scan_dtype.itemsize % 4 == 0
{
pycl_printf(("write: %d\n", unit_base + offset)); ints_per_wg = wg_size
partial_scan_buffer[unit_base + offset] = ints_to_store = scan_dtype.itemsize*wg_size*k_group_size // 4
ldata[offset % K][offset / K];
} %>
const index_type scan_types_per_int = ${scan_dtype.itemsize//4};
%for store_base in xrange(0, ints_to_store, ints_per_wg):
<%
# Observe that ints_to_store is divisible by the work group size already,
# so we won't go out of bounds that way.
assert store_base + ints_per_wg <= ints_to_store
%>
%if is_tail:
if (${store_base} + LID_0 < scan_types_per_int*(interval_end - unit_base))
%endif
{
index_type linear_index = ${store_base} + LID_0;
index_type linear_scan_data_idx = linear_index / scan_types_per_int;
index_type remainder = linear_index - linear_scan_data_idx * scan_types_per_int;
__local int *src = (__local int *) &(
ldata[linear_scan_data_idx % K][linear_scan_data_idx / K]);
dest[linear_index] = src[remainder];
}
%endfor
} }
pycl_printf(("after write\n")); pycl_printf(("after write\n"));
...@@ -508,15 +534,15 @@ void ${name_prefix}_scan_intervals( ...@@ -508,15 +534,15 @@ void ${name_prefix}_scan_intervals(
% endfor % endfor
// write interval sum // write interval sum
if (LID_0 == 0) %if is_first_level:
{ if (LID_0 == 0)
%if is_first_level: {
interval_results[GID_0] = partial_scan_buffer[interval_end - 1]; interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
%endif %if is_segmented:
%if is_segmented and is_first_level: g_first_segment_start_in_interval[GID_0] = first_segment_start_in_interval;
g_first_segment_start_in_interval[GID_0] = first_segment_start_in_interval; %endif
%endif }
} %endif
} }
""" """
...@@ -713,6 +739,8 @@ _PREFIX_WORDS = set(""" ...@@ -713,6 +739,8 @@ _PREFIX_WORDS = set("""
l_ o_mod_k o_div_k l_segment_start_flags scan_value sum l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
first_seg_start_in_interval g_segment_start_flags first_seg_start_in_interval g_segment_start_flags
group_base seg_end my_val DEBUG ARGS group_base seg_end my_val DEBUG ARGS
ints_to_store ints_per_wg scan_types_per_int linear_index
linear_scan_data_idx dest src store_base
LID_2 LID_1 LID_0 LID_2 LID_1 LID_0
LDIM_0 LDIM_1 LDIM_2 LDIM_0 LDIM_1 LDIM_2
...@@ -721,8 +749,11 @@ _PREFIX_WORDS = set(""" ...@@ -721,8 +749,11 @@ _PREFIX_WORDS = set("""
""".split()) """.split())
_IGNORED_WORDS = set(""" _IGNORED_WORDS = set("""
4 32
typedef for endfor if void while endwhile endfor endif else const printf typedef for endfor if void while endwhile endfor endif else const printf
None return bool n char true false ifdef pycl_printf None return bool n char true false ifdef pycl_printf str xrange assert
np iinfo max itemsize
set iteritems len setdefault set iteritems len setdefault
...@@ -754,17 +785,18 @@ _IGNORED_WORDS = set(""" ...@@ -754,17 +785,18 @@ _IGNORED_WORDS = set("""
intra Therefore find code assumption intra Therefore find code assumption
branch workgroup complicated granularity phase remainder than simpler branch workgroup complicated granularity phase remainder than simpler
We smaller look ifs lots self behind allow barriers whole loop We smaller look ifs lots self behind allow barriers whole loop
after after already Observe achieve contiguous stores hard go with by math
size won t way divisible bit so
is_tail is_first_level input_expr argument_signature preamble is_tail is_first_level input_expr argument_signature preamble
double_support neutral output_statement index_type_max double_support neutral output_statement
k_group_size name_prefix is_segmented index_ctype scan_ctype k_group_size name_prefix is_segmented index_dtype scan_dtype
wg_size is_segment_start_expr fetch_expr_offsets wg_size is_segment_start_expr fetch_expr_offsets
arg_ctypes ife_offsets input_fetch_exprs def arg_ctypes ife_offsets input_fetch_exprs def
ife_offset arg_name local_fetch_expr_args update_body ife_offset arg_name local_fetch_expr_args update_body
update_loop_lookbehind update_loop_plain update_loop update_loop_lookbehind update_loop_plain update_loop
use_lookbehind_update store_segment_start_flags use_lookbehind_update store_segment_start_flags
update_loop first_seg update_loop first_seg scan_dtype dtype_to_ctype
a b prev_item i last_item prev_value a b prev_item i last_item prev_value
N NO_SEG_BOUNDARY across_seg_boundary N NO_SEG_BOUNDARY across_seg_boundary
...@@ -801,6 +833,9 @@ class _ScanKernelInfo(Record): ...@@ -801,6 +833,9 @@ class _ScanKernelInfo(Record):
# }}} # }}}
class ScanPerformanceWarning(UserWarning):
pass
class _GenericScanKernelBase(object): class _GenericScanKernelBase(object):
# {{{ constructor, argument processing # {{{ constructor, argument processing
...@@ -901,6 +936,9 @@ class _GenericScanKernelBase(object): ...@@ -901,6 +936,9 @@ class _GenericScanKernelBase(object):
"'output_statement' otherwise does something non-trivial", "'output_statement' otherwise does something non-trivial",
stacklevel=2) stacklevel=2)
if dtype.itemsize % 4 != 0:
raise TypeError("scan value type must have size divisible by 4 bytes")
self.index_dtype = np.dtype(index_dtype) self.index_dtype = np.dtype(index_dtype)
if np.iinfo(self.index_dtype).min >= 0: if np.iinfo(self.index_dtype).min >= 0:
raise TypeError("index_dtype must be signed") raise TypeError("index_dtype must be signed")
...@@ -930,8 +968,10 @@ class _GenericScanKernelBase(object): ...@@ -930,8 +968,10 @@ class _GenericScanKernelBase(object):
raise RuntimeError("input_fetch_expr offsets must either be 0 or -1") raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
self.input_fetch_exprs = input_fetch_exprs self.input_fetch_exprs = input_fetch_exprs
arg_dtypes = {}
arg_ctypes = {} arg_ctypes = {}
for arg in self.parsed_args: for arg in self.parsed_args:
arg_dtypes[arg.name] = arg.dtype
arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype) arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
self.options = options self.options = options
...@@ -943,12 +983,14 @@ class _GenericScanKernelBase(object): ...@@ -943,12 +983,14 @@ class _GenericScanKernelBase(object):
from pyopencl.characterize import has_double_support from pyopencl.characterize import has_double_support
self.code_variables = dict( self.code_variables = dict(
np=np,
dtype_to_ctype=dtype_to_ctype,
preamble=preamble, preamble=preamble,
name_prefix=name_prefix, name_prefix=name_prefix,
index_ctype=dtype_to_ctype(self.index_dtype), index_dtype=self.index_dtype,
index_type_max=str(np.iinfo(self.index_dtype).max), scan_dtype=dtype,
scan_ctype=dtype_to_ctype(dtype),
is_segmented=self.is_segmented, is_segmented=self.is_segmented,
arg_dtypes=arg_dtypes,
arg_ctypes=arg_ctypes, arg_ctypes=arg_ctypes,
scan_expr=_process_code_for_macro(scan_expr), scan_expr=_process_code_for_macro(scan_expr),
neutral=_process_code_for_macro(neutral), neutral=_process_code_for_macro(neutral),
...@@ -1008,7 +1050,7 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1008,7 +1050,7 @@ class GenericScanKernel(_GenericScanKernelBase):
for _, _, wg_size in solutions) for _, _, wg_size in solutions)
if have_sol_above_floor: if have_sol_above_floor:
# delete all the others # delete all solutions not meeting the wg size floor
solutions = [(total, k_group_size, wg_size) solutions = [(total, k_group_size, wg_size)
for total, k_group_size, wg_size in solutions for total, k_group_size, wg_size in solutions
if wg_size >= wg_size_floor] if wg_size >= wg_size_floor]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment