Skip to content
Snippets Groups Projects

Support vector arguments with offset in scan kernels.

Merged Matt Wala requested to merge vectorarg-with-offset-for-scan into master
Files
3
+ 18
4
@@ -35,7 +35,8 @@ from pyopencl.tools import (dtype_to_ctype, bitlog2,
KernelTemplateBase, _process_code_for_macro,
get_arg_list_scalar_arg_dtypes,
context_dependent_memoize,
_NumpyTypesKeyBuilder)
_NumpyTypesKeyBuilder,
get_arg_offset_adjuster_code)
import pyopencl._mymako as mako
from pyopencl._cluda import CLUDA_PREAMBLE
@@ -148,6 +149,8 @@ void ${kernel_name}(
%endif
)
{
${arg_offset_adjustment}
// index K in first dimension used for carry storage
%if use_bank_conflict_avoidance:
// Avoid bank conflicts by adding a single 32-bit value to the size of
@@ -618,6 +621,8 @@ void ${name_prefix}_final_update(
%endif
)
{
${arg_offset_adjustment}
%if use_lookbehind_update:
LOCAL_MEM scan_type ldata[WG_SIZE];
%endif
@@ -998,7 +1003,7 @@ class _GenericScanKernelBase(object):
resulting in a C `bool` value that determines whether a new
scan segments starts at index *i*. If given, makes the scan a
segmented scan. Has access to the current index `i`, the result
of *input_expr* as a, and in addition may use *arguments* and
of *input_expr* as `a`, and in addition may use *arguments* and
*input_fetch_expr* variables just like *input_expr*.
If it returns true, then previous sums will not spill over into the
@@ -1346,6 +1351,7 @@ class GenericScanKernel(_GenericScanKernelBase):
final_update_src = str(final_update_tpl.render(
wg_size=update_wg_size,
output_statement=self.output_statement,
arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
argument_signature=", ".join(
arg.declarator() for arg in self.parsed_args),
is_segment_start_expr=self.is_segment_start_expr,
@@ -1421,6 +1427,7 @@ class GenericScanKernel(_GenericScanKernelBase):
wg_size=wg_size,
input_expr=input_expr,
k_group_size=k_group_size,
arg_offset_adjustment=get_arg_offset_adjuster_code(arguments),
argument_signature=", ".join(arg.declarator() for arg in arguments),
is_segment_start_expr=is_segment_start_expr,
input_fetch_exprs=input_fetch_exprs,
@@ -1475,7 +1482,9 @@ class GenericScanKernel(_GenericScanKernelBase):
from pyopencl.tools import VectorArg
for arg_descr, arg_val in zip(self.parsed_args, args):
if isinstance(arg_descr, VectorArg):
data_args.append(arg_val.data)
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
else:
data_args.append(arg_val)
@@ -1583,6 +1592,8 @@ void ${name_prefix}_debug_scan(
scan_type current = ${neutral};
scan_type prev;
${arg_offset_adjustment}
for (index_type i = 0; i < N; ++i)
{
%for name, arg_name, ife_offset in input_fetch_exprs:
@@ -1636,6 +1647,7 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
scan_src = str(scan_tpl.render(
output_statement=self.output_statement,
arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
argument_signature=", ".join(
arg.declarator() for arg in self.parsed_args),
is_segment_start_expr=self.is_segment_start_expr,
@@ -1680,7 +1692,9 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
from pyopencl.tools import VectorArg
for arg_descr, arg_val in zip(self.parsed_args, args):
if isinstance(arg_descr, VectorArg):
data_args.append(arg_val.data)
data_args.append(arg_val.base_data)
if arg_descr.with_offset:
data_args.append(arg_val.offset)
else:
data_args.append(arg_val)
Loading