From cdf69f6163f4a52e9f860ab47669048e20221984 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 4 May 2016 18:18:53 -0500 Subject: [PATCH] Make radix sort work with debug scan kernel --- pyopencl/scan.py | 33 ++++++++++++++++++++++++++------- test/test_algorithm.py | 14 +++++++++----- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 0ca0aa6a..272e6996 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -764,7 +764,7 @@ _PREFIX_WORDS = set(""" group_base seg_end my_val DEBUG ARGS ints_to_store ints_per_wg scan_types_per_int linear_index linear_scan_data_idx dest src store_base wrapped_scan_type - dummy + dummy scan_tmp LID_2 LID_1 LID_0 LDIM_0 LDIM_1 LDIM_2 @@ -1419,11 +1419,12 @@ DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + r"""//CL// KERNEL REQD_WG_SIZE(1, 1, 1) void ${name_prefix}_debug_scan( + __global scan_type *scan_tmp, ${argument_signature}, const index_type N) { - scan_type item = ${neutral}; - scan_type last_item; + scan_type current = ${neutral}; + scan_type prev; for (index_type i = 0; i < N; ++i) { @@ -1439,18 +1440,31 @@ void ${name_prefix}_debug_scan( scan_type my_val = INPUT_EXPR(i); - last_item = item; + prev = current; %if is_segmented: bool is_seg_start = IS_SEG_START(i, my_val); %endif - item = SCAN_EXPR(last_item, my_val, + current = SCAN_EXPR(prev, my_val, %if is_segmented: is_seg_start %else: false %endif ); + scan_tmp[i] = current; + } + + scan_type last_item = scan_tmp[N-1]; + + for (index_type i = 0; i < N; ++i) + { + scan_type item = scan_tmp[i]; + scan_type prev_item; + if (i) + prev_item = scan_tmp[i-1]; + else + prev_item = ${neutral}; { ${output_statement}; @@ -1477,7 +1491,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase): self.kernel = getattr( scan_prg, self.name_prefix+"_debug_scan") scalar_arg_dtypes = ( - get_arg_list_scalar_arg_dtypes(self.parsed_args) + [None] + + get_arg_list_scalar_arg_dtypes(self.parsed_args) + [self.index_dtype]) self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes) @@ -1500,7 +1515,11 @@ class GenericDebugScanKernel(_GenericScanKernelBase): if n is None: n, = first_array.shape - data_args = [] + scan_tmp = cl.array.empty(queue, + n, dtype=self.dtype, + allocator=allocator) + + data_args = [scan_tmp.data] from pyopencl.tools import VectorArg for arg_descr, arg_val in zip(self.parsed_args, args): if isinstance(arg_descr, VectorArg): diff --git a/test/test_algorithm.py b/test/test_algorithm.py index 02e6adcd..952a38dd 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -38,7 +38,8 @@ import pyopencl.array as cl_array # noqa from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) from pyopencl.characterize import has_double_support, has_struct_arg_count_bug -from pyopencl.scan import InclusiveScanKernel, ExclusiveScanKernel +from pyopencl.scan import (InclusiveScanKernel, ExclusiveScanKernel, + GenericScanKernel, GenericDebugScanKernel) # {{{ elementwise @@ -668,7 +669,6 @@ def test_segmented_scan(ctx_factory): else: output_statement = "out[i] = item" - from pyopencl.scan import GenericScanKernel knl = GenericScanKernel(context, dtype, arguments="__global %s *ary, __global char *segflags, " "__global %s *out" % (ctype, ctype), @@ -748,7 +748,8 @@ def test_segmented_scan(ctx_factory): print("%d excl:%s done" % (n, is_exclusive)) -def test_sort(ctx_factory): +@pytest.mark.parametrize("scan_kernel", [GenericScanKernel, GenericDebugScanKernel]) +def test_sort(ctx_factory, scan_kernel): from pytest import importorskip importorskip("mako") @@ -759,7 +760,7 @@ def test_sort(ctx_factory): from pyopencl.algorithm import RadixSort sort = RadixSort(context, "int *ary", key_expr="ary[i]", - sort_arg_names=["ary"]) + sort_arg_names=["ary"], scan_kernel=scan_kernel) from pyopencl.clrandom import RanluxGenerator rng = RanluxGenerator(queue, seed=15) @@ -768,6 +769,9 @@ def test_sort(ctx_factory): # intermediate arrays for largest size cause out-of-memory on low-end GPUs for n in scan_test_counts[:-1]: + if n >= 2000 and isinstance(scan_kernel, GenericDebugScanKernel): + continue + print(n) print(" rng") @@ -785,7 +789,7 @@ def test_sort(ctx_factory): numpy_elapsed = numpy_end-dev_end dev_elapsed = dev_end-dev_start - print (" dev: %.2f MKeys/s numpy: %.2f MKeys/s ratio: %.2fx" % ( + print(" dev: %.2f MKeys/s numpy: %.2f MKeys/s ratio: %.2fx" % ( 1e-6*n/dev_elapsed, 1e-6*n/numpy_elapsed, numpy_elapsed/dev_elapsed)) assert (a_dev_sorted.get() == a_sorted).all() -- GitLab