From 5d983b09f87b27e6474f21f6d04d38f11550149c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 24 Jul 2013 20:18:20 -0500 Subject: [PATCH] Fix lmem usage estimate in scan, parametrize scan test --- pyopencl/scan.py | 31 +++++++++++++++++++++------- test/test_algorithm.py | 47 +++++++++++++++++++----------------------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/pyopencl/scan.py b/pyopencl/scan.py index ae380ff2..41dd7912 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -143,7 +143,7 @@ void ${name_prefix}_scan_intervals( ) { // index K in first dimension used for carry storage - %if scan_dtype.itemsize > 4 and scan_dtype.itemsize % 8 == 0 and is_gpu: + %if use_bank_conflict_avoidance: // Avoid bank conflicts by adding a single 32-bit value to the size of // the scan type. struct __attribute__ ((__packed__)) wrapped_scan_type @@ -1064,7 +1064,10 @@ class GenericScanKernel(_GenericScanKernelBase): dev.local_mem_size for dev in self.devices) - if self.devices[0].type == cl.device_type.CPU: + is_cpu = self.devices[0].type & cl.device_type.CPU + is_gpu = self.devices[0].type & cl.device_type.GPU + + if is_cpu: # (about the widest vector a CPU can support, also taking # into account that CPUs don't hide latency by large work groups max_scan_wg_size = 16 @@ -1073,6 +1076,9 @@ class GenericScanKernel(_GenericScanKernelBase): max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices) wg_size_multiples = 64 + use_bank_conflict_avoidance = ( + self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0 and is_gpu) + # k_group_size should be a power of two because of in-kernel # division by that number. @@ -1082,11 +1088,12 @@ class GenericScanKernel(_GenericScanKernelBase): wg_size_multiples): k_group_size = 2**k_exp - lmem_use = self.get_local_mem_use(wg_size, k_group_size) + lmem_use = self.get_local_mem_use(wg_size, k_group_size, + use_bank_conflict_avoidance) if lmem_use + 256 <= avail_local_mem: solutions.append((wg_size*k_group_size, k_group_size, wg_size)) - if self.devices[0].type & cl.device_type.GPU: + if is_gpu: from pytools import any for wg_size_floor in [256, 192, 128]: have_sol_above_floor = any(wg_size >= wg_size_floor @@ -1109,7 +1116,8 @@ class GenericScanKernel(_GenericScanKernelBase): input_fetch_exprs=self.input_fetch_exprs, is_first_level=True, store_segment_start_flags=self.store_segment_start_flags, - k_group_size=k_group_size) + k_group_size=k_group_size, + use_bank_conflict_avoidance=use_bank_conflict_avoidance) # Will this device actually let us execute this kernel # at the desired work group size? Building it is the @@ -1164,6 +1172,7 @@ class GenericScanKernel(_GenericScanKernelBase): is_first_level=False, store_segment_start_flags=False, k_group_size=k_group_size, + use_bank_conflict_avoidance=use_bank_conflict_avoidance, **second_level_build_kwargs) # }}} @@ -1202,7 +1211,7 @@ class GenericScanKernel(_GenericScanKernelBase): # {{{ scan kernel build/properties - def get_local_mem_use(self, k_group_size, wg_size): + def get_local_mem_use(self, k_group_size, wg_size, use_bank_conflict_avoidance): arg_dtypes = {} for arg in self.parsed_args: arg_dtypes[arg.name] = arg.dtype @@ -1211,9 +1220,13 @@ class GenericScanKernel(_GenericScanKernelBase): for name, arg_name, ife_offset in self.input_fetch_exprs: fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset) + itemsize = self.dtype.itemsize + if use_bank_conflict_avoidance: + itemsize += 4 + return ( # ldata - self.dtype.itemsize*(k_group_size+1)*(wg_size+1) + itemsize*(k_group_size+1)*(wg_size+1) # l_segment_start_flags + k_group_size*wg_size @@ -1228,7 +1241,8 @@ class GenericScanKernel(_GenericScanKernelBase): def build_scan_kernel(self, max_wg_size, arguments, input_expr, is_segment_start_expr, input_fetch_exprs, is_first_level, - store_segment_start_flags, k_group_size): + store_segment_start_flags, k_group_size, + use_bank_conflict_avoidance): scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments) # Empirically found on Nv hardware: no need to be bigger than this size @@ -1245,6 +1259,7 @@ class GenericScanKernel(_GenericScanKernelBase): input_fetch_exprs=input_fetch_exprs, is_first_level=is_first_level, store_segment_start_flags=store_segment_start_flags, + use_bank_conflict_avoidance=use_bank_conflict_avoidance, **self.code_variables)) prg = cl.Program(self.context, scan_src).build(self.options) diff --git a/test/test_algorithm.py b/test/test_algorithm.py index bcc687e2..7f0f9f4c 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -486,44 +486,39 @@ scan_test_counts = [ ] -def test_scan(ctx_factory): +@pytest.mark.parametrize("dtype", [np.int32, np.int64]) +@pytest.mark.parametrize("scan_cls", [InclusiveScanKernel, ExclusiveScanKernel]) +def test_scan(ctx_factory, dtype, scan_cls): from pytest import importorskip importorskip("mako") context = ctx_factory() queue = cl.CommandQueue(context) - from pyopencl.scan import InclusiveScanKernel, ExclusiveScanKernel + knl = scan_cls(context, dtype, "a+b", "0") - dtype = np.int32 - for cls in [ - InclusiveScanKernel, - ExclusiveScanKernel - ]: - knl = cls(context, dtype, "a+b", "0") - - for n in scan_test_counts: - host_data = np.random.randint(0, 10, n).astype(dtype) - dev_data = cl_array.to_device(queue, host_data) + for n in scan_test_counts: + host_data = np.random.randint(0, 10, n).astype(dtype) + dev_data = cl_array.to_device(queue, host_data) - # /!\ fails on Nv GT2?? for some drivers - assert (host_data == dev_data.get()).all() + # /!\ fails on Nv GT2?? for some drivers + assert (host_data == dev_data.get()).all() - knl(dev_data) + knl(dev_data) - desired_result = np.cumsum(host_data, axis=0) - if cls is ExclusiveScanKernel: - desired_result -= host_data + desired_result = np.cumsum(host_data, axis=0) + if scan_cls is ExclusiveScanKernel: + desired_result -= host_data - is_ok = (dev_data.get() == desired_result).all() - if 1 and not is_ok: - print("something went wrong, summarizing error...") - print(summarize_error(dev_data.get(), desired_result, host_data)) + is_ok = (dev_data.get() == desired_result).all() + if 1 and not is_ok: + print("something went wrong, summarizing error...") + print(summarize_error(dev_data.get(), desired_result, host_data)) - print("n:%d %s worked:%s" % (n, cls, is_ok)) - assert is_ok - from gc import collect - collect() + print("dtype:%s n:%d %s worked:%s" % (dtype, n, scan_cls, is_ok)) + assert is_ok + from gc import collect + collect() def test_copy_if(ctx_factory): -- GitLab