Commit 5d983b09 authored by Andreas Klöckner

Fix lmem usage estimate in scan, parametrize scan test

parent 98135cbe
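The first file touched here is pyopencl's scan kernel implementation. The underlying bug: when bank-conflict avoidance is active, every scan element kept in local memory ("ldata") carries an extra 32-bit dummy field, but get_local_mem_use() estimated usage from the raw dtype itemsize, so the work-group-size search could accept configurations whose padded kernels no longer fit. A back-of-the-envelope sketch in Python (the device and size numbers below are hypothetical, chosen only to show the estimate crossing the limit):

# Illustration only -- hypothetical example numbers, not from the commit.
itemsize = 8         # e.g. an int64/double scan type
wg_size = 256        # candidate work group size
k_group_size = 16    # candidate per-work-item group size

# old (under-)estimate: raw element size only
ldata_old = itemsize*(k_group_size+1)*(wg_size+1)         # 34952 bytes

# fixed estimate: +4 bytes per element of bank-conflict padding
ldata_new = (itemsize + 4)*(k_group_size+1)*(wg_size+1)   # 52428 bytes

avail_local_mem = 48*1024  # a common GPU local memory size
print(ldata_old + 256 <= avail_local_mem)  # True  -> old code accepts
print(ldata_new + 256 <= avail_local_mem)  # False -> actually doesn't fit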
@@ -143,7 +143,7 @@ void ${name_prefix}_scan_intervals(
     )
 {
     // index K in first dimension used for carry storage
-    %if scan_dtype.itemsize > 4 and scan_dtype.itemsize % 8 == 0 and is_gpu:
+    %if use_bank_conflict_avoidance:
         // Avoid bank conflicts by adding a single 32-bit value to the size of
         // the scan type.
         struct __attribute__ ((__packed__)) wrapped_scan_type
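For context on the %if branch above: padding an 8-byte element to 12 bytes changes the stride between consecutive work items' accesses so they no longer collide in the same local-memory bank. A minimal sketch of the effect, assuming the common 32-banks-by-4-bytes layout (a hardware assumption for illustration, not something stated in the commit):

# Assumed bank layout: 32 banks, 4 bytes wide (typical GPU hardware of
# this era; an assumption for illustration).
NUM_BANKS = 32
BANK_BYTES = 4

def banks_touched(stride_bytes, n_work_items=32):
    # distinct banks hit by each work item's first 4-byte word
    return len({(i*stride_bytes // BANK_BYTES) % NUM_BANKS
                for i in range(n_work_items)})

print(banks_touched(8))    # 16 -> two-way conflicts for plain 8-byte types
print(banks_touched(12))   # 32 -> conflict-free with the 4-byte padding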
@@ -1064,7 +1064,10 @@ class GenericScanKernel(_GenericScanKernelBase):
                 dev.local_mem_size
                 for dev in self.devices)

-        if self.devices[0].type == cl.device_type.CPU:
+        is_cpu = self.devices[0].type & cl.device_type.CPU
+        is_gpu = self.devices[0].type & cl.device_type.GPU
+
+        if is_cpu:
             # (about the widest vector a CPU can support, also taking
             # into account that CPUs don't hide latency by large work groups
             max_scan_wg_size = 16
@@ -1073,6 +1076,9 @@ class GenericScanKernel(_GenericScanKernelBase):
             max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
             wg_size_multiples = 64

+        use_bank_conflict_avoidance = (
+                self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0 and is_gpu)
+
         # k_group_size should be a power of two because of in-kernel
         # division by that number.
@@ -1082,11 +1088,12 @@ class GenericScanKernel(_GenericScanKernelBase):
                     wg_size_multiples):
                 k_group_size = 2**k_exp
-                lmem_use = self.get_local_mem_use(wg_size, k_group_size)
+                lmem_use = self.get_local_mem_use(wg_size, k_group_size,
+                        use_bank_conflict_avoidance)
                 if lmem_use + 256 <= avail_local_mem:
                     solutions.append((wg_size*k_group_size, k_group_size, wg_size))

-        if self.devices[0].type & cl.device_type.GPU:
+        if is_gpu:
             from pytools import any
             for wg_size_floor in [256, 192, 128]:
                 have_sol_above_floor = any(wg_size >= wg_size_floor
@@ -1109,7 +1116,8 @@ class GenericScanKernel(_GenericScanKernelBase):
                     input_fetch_exprs=self.input_fetch_exprs,
                     is_first_level=True,
                     store_segment_start_flags=self.store_segment_start_flags,
-                    k_group_size=k_group_size)
+                    k_group_size=k_group_size,
+                    use_bank_conflict_avoidance=use_bank_conflict_avoidance)

             # Will this device actually let us execute this kernel
             # at the desired work group size? Building it is the
@@ -1164,6 +1172,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                     is_first_level=False,
                     store_segment_start_flags=False,
                     k_group_size=k_group_size,
+                    use_bank_conflict_avoidance=use_bank_conflict_avoidance,
                     **second_level_build_kwargs)

     # }}}
@@ -1202,7 +1211,7 @@ class GenericScanKernel(_GenericScanKernelBase):

     # {{{ scan kernel build/properties

-    def get_local_mem_use(self, k_group_size, wg_size):
+    def get_local_mem_use(self, k_group_size, wg_size, use_bank_conflict_avoidance):
         arg_dtypes = {}
         for arg in self.parsed_args:
             arg_dtypes[arg.name] = arg.dtype
@@ -1211,9 +1220,13 @@ class GenericScanKernel(_GenericScanKernelBase):
         for name, arg_name, ife_offset in self.input_fetch_exprs:
             fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)

+        itemsize = self.dtype.itemsize
+        if use_bank_conflict_avoidance:
+            itemsize += 4
+
         return (
                 # ldata
-                self.dtype.itemsize*(k_group_size+1)*(wg_size+1)
+                itemsize*(k_group_size+1)*(wg_size+1)

                 # l_segment_start_flags
                 + k_group_size*wg_size
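Read together with the search-loop hunk above, the corrected sizing logic amounts to the following (a condensed sketch, not the actual method; the real get_local_mem_use() also counts l_segment_start_flags and the input-fetch buffers visible in the context lines):

# Condensed sketch of the corrected check; names mirror the diff.
def fits_in_local_mem(itemsize, wg_size, k_group_size,
                      use_bank_conflict_avoidance, avail_local_mem):
    if use_bank_conflict_avoidance:
        itemsize += 4  # the wrapped scan type carries an extra 32-bit field
    ldata = itemsize*(k_group_size+1)*(wg_size+1)
    return ldata + 256 <= avail_local_mem  # same 256-byte slack as the loop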
@@ -1228,7 +1241,8 @@ class GenericScanKernel(_GenericScanKernelBase):

     def build_scan_kernel(self, max_wg_size, arguments, input_expr,
             is_segment_start_expr, input_fetch_exprs, is_first_level,
-            store_segment_start_flags, k_group_size):
+            store_segment_start_flags, k_group_size,
+            use_bank_conflict_avoidance):
         scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)

         # Empirically found on Nv hardware: no need to be bigger than this size
@@ -1245,6 +1259,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                 input_fetch_exprs=input_fetch_exprs,
                 is_first_level=is_first_level,
                 store_segment_start_flags=store_segment_start_flags,
+                use_bank_conflict_avoidance=use_bank_conflict_avoidance,
                 **self.code_variables))

         prg = cl.Program(self.context, scan_src).build(self.options)
@@ -486,44 +486,39 @@ scan_test_counts = [
 ]


-def test_scan(ctx_factory):
+@pytest.mark.parametrize("dtype", [np.int32, np.int64])
+@pytest.mark.parametrize("scan_cls", [InclusiveScanKernel, ExclusiveScanKernel])
+def test_scan(ctx_factory, dtype, scan_cls):
     from pytest import importorskip
     importorskip("mako")

     context = ctx_factory()
     queue = cl.CommandQueue(context)

-    from pyopencl.scan import InclusiveScanKernel, ExclusiveScanKernel
+    knl = scan_cls(context, dtype, "a+b", "0")

-    dtype = np.int32
-    for cls in [
-            InclusiveScanKernel,
-            ExclusiveScanKernel
-            ]:
-        knl = cls(context, dtype, "a+b", "0")
-
-        for n in scan_test_counts:
-            host_data = np.random.randint(0, 10, n).astype(dtype)
-            dev_data = cl_array.to_device(queue, host_data)
-
-            # /!\ fails on Nv GT2?? for some drivers
-            assert (host_data == dev_data.get()).all()
-
-            knl(dev_data)
-
-            desired_result = np.cumsum(host_data, axis=0)
-            if cls is ExclusiveScanKernel:
-                desired_result -= host_data
-
-            is_ok = (dev_data.get() == desired_result).all()
-            if 1 and not is_ok:
-                print("something went wrong, summarizing error...")
-                print(summarize_error(dev_data.get(), desired_result, host_data))
-
-            print("n:%d %s worked:%s" % (n, cls, is_ok))
-            assert is_ok
-        from gc import collect
-        collect()
+    for n in scan_test_counts:
+        host_data = np.random.randint(0, 10, n).astype(dtype)
+        dev_data = cl_array.to_device(queue, host_data)
+
+        # /!\ fails on Nv GT2?? for some drivers
+        assert (host_data == dev_data.get()).all()
+
+        knl(dev_data)
+
+        desired_result = np.cumsum(host_data, axis=0)
+        if scan_cls is ExclusiveScanKernel:
+            desired_result -= host_data
+
+        is_ok = (dev_data.get() == desired_result).all()
+        if 1 and not is_ok:
+            print("something went wrong, summarizing error...")
+            print(summarize_error(dev_data.get(), desired_result, host_data))
+
+        print("dtype:%s n:%d %s worked:%s" % (dtype, n, scan_cls, is_ok))
+        assert is_ok
+
+        from gc import collect
+        collect()


 def test_copy_if(ctx_factory):
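On the test side, the two stacked parametrize decorators replace the hand-rolled loop with a cross product of four independent cases (two dtypes times two scan classes), so a failure now names the exact combination; the decorators presumably rely on the scan classes being imported at module scope, which this excerpt does not show. A self-contained demonstration of the mechanism (a hypothetical test, not pyopencl code):

# Standalone demo: pytest expands stacked parametrize decorators into
# the cross product, here 2 x 2 = 4 test cases.
import numpy as np
import pytest

@pytest.mark.parametrize("dtype", [np.int32, np.int64])
@pytest.mark.parametrize("kind", ["inclusive", "exclusive"])
def test_scan_reference(dtype, kind):
    host_data = np.random.randint(0, 10, 50).astype(dtype)
    desired = np.cumsum(host_data, axis=0)
    if kind == "exclusive":
        desired = desired - host_data  # exclusive scan drops each element's own value
    assert desired[0] == (0 if kind == "exclusive" else host_data[0])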