From 2b0ddacc16a8e5469216615537910581b3353ba8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 15 Jul 2015 15:58:26 -0500 Subject: [PATCH] More interface restructuring --- pyopencl/bitonic_sort.py | 79 +++++++++++++++++++--------------------- test/test_algorithm.py | 10 +++-- 2 files changed, 43 insertions(+), 46 deletions(-) diff --git a/pyopencl/bitonic_sort.py b/pyopencl/bitonic_sort.py index 1408168f..a1d9b7e2 100644 --- a/pyopencl/bitonic_sort.py +++ b/pyopencl/bitonic_sort.py @@ -42,6 +42,8 @@ from functools import reduce from pytools import memoize_method from mako.template import Template +import pyopencl.bitonic_sort_templates as _tmpl + def _is_power_of_2(n): from pyopencl.tools import bitlog2 @@ -56,36 +58,28 @@ class BitonicSort(object): .. versionadded:: 2015.2 """ - def __init__(self, context, shape, key_dtype, idx_dtype=None, axis=0): - import pyopencl.bitonic_sort_templates as tmpl - - self.cached_defs = {} - self.kernels_srcs = { - 'B2': tmpl.ParallelBitonic_B2, - 'B4': tmpl.ParallelBitonic_B4, - 'B8': tmpl.ParallelBitonic_B8, - 'B16': tmpl.ParallelBitonic_B16, - 'C4': tmpl.ParallelBitonic_C4, - 'BL': tmpl.ParallelBitonic_Local, - 'BLO': tmpl.ParallelBitonic_Local_Optim, - 'PML': tmpl.ParallelMerge_Local - } + kernels_srcs = { + 'B2': _tmpl.ParallelBitonic_B2, + 'B4': _tmpl.ParallelBitonic_B4, + 'B8': _tmpl.ParallelBitonic_B8, + 'B16': _tmpl.ParallelBitonic_B16, + 'C4': _tmpl.ParallelBitonic_C4, + 'BL': _tmpl.ParallelBitonic_Local, + 'BLO': _tmpl.ParallelBitonic_Local_Optim, + 'PML': _tmpl.ParallelMerge_Local + } + + def __init__(self, context, key_dtype, idx_dtype=None): self.dtype = dtype_to_ctype(key_dtype) self.context = context - self.axis = axis if idx_dtype is None: - self.argsort = 0 self.idx_t = 'uint' # Dummy else: - self.argsort = 1 self.idx_t = dtype_to_ctype(idx_dtype) - self.defstpl = Template(tmpl.defines) - self.run_queue = self.sort_b_prepare_wl(shape, self.axis) - - def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None): + def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None, axis=0): if queue is None: queue = arr.queue @@ -95,18 +89,23 @@ class BitonicSort(object): last_evt = cl.enqueue_marker(queue, wait_for=wait_for) - if arr.shape[self.axis] == 0: + if arr.shape[axis] == 0: return arr, last_evt - if not _is_power_of_2(arr.shape[self.axis]): + if not _is_power_of_2(arr.shape[axis]): raise ValueError("sorted array axis length must be a power of 2") arr = arr.copy() if mkcpy else arr - run_queue = self.run_queue + if idx is None: + argsort = 0 + else: + argsort = 1 + + run_queue = self.sort_b_prepare_wl(argsort, arr.shape, axis) knl, nt, wg, aux = run_queue[0] - if self.argsort and idx is not None: + if idx is not None: if aux: last_evt = knl( queue, (nt,), wg, arr.data, idx.data, @@ -118,7 +117,7 @@ class BitonicSort(object): queue, (nt,), wg, arr.data, idx.data, wait_for=[last_evt]) - elif not self.argsort: + else: if aux: last_evt = knl( queue, (nt,), wg, arr.data, @@ -127,29 +126,24 @@ class BitonicSort(object): for knl, nt, wg, _ in run_queue[1:]: last_evt = knl(queue, (nt,), wg, arr.data, wait_for=[last_evt]) - else: - raise ValueError("Array of indexes required for this sorter. If argsort is not needed,\ - recreate sorter witout index datatype provided.") return arr, last_evt @memoize_method - def get_program(self, letter, params): - if params in self.cached_defs.keys(): - defs = self.cached_defs[params] - else: - defs = self.defstpl.render( - NS="\\", argsort=self.argsort, inc=params[0], dir=params[1], - dtype=params[2], idxtype=params[3], - dsize=params[4], nsize=params[5]) + def get_program(self, letter, argsort, params): + defstpl = Template(_tmpl.defines) - self.cached_defs[params] = defs + defs = defstpl.render( + NS="\\", argsort=argsort, inc=params[0], dir=params[1], + dtype=params[2], idxtype=params[3], + dsize=params[4], nsize=params[5]) - kid = Template(self.kernels_srcs[letter]).render(argsort=self.argsort) + kid = Template(self.kernels_srcs[letter]).render(argsort=argsort) prg = cl.Program(self.context, defs + kid).build() return prg - def sort_b_prepare_wl(self, shape, axis): + @memoize_method + def sort_b_prepare_wl(self, argsort, shape, axis): run_queue = [] ds = int(shape[axis]) size = reduce(mul, shape) @@ -164,7 +158,8 @@ class BitonicSort(object): wg = min(ds, self.context.devices[0].max_work_group_size) length = wg >> 1 - prg = self.get_program('BLO', (1, 1, self.dtype, self.idx_t, ds, ns)) + prg = self.get_program( + 'BLO', argsort, (1, 1, self.dtype, self.idx_t, ds, ns)) run_queue.append((prg.run, size, (wg,), True)) while length < ds: @@ -187,7 +182,7 @@ class BitonicSort(object): nthreads = size >> ninc - prg = self.get_program(letter, + prg = self.get_program(letter, argsort, (inc, direction, self.dtype, self.idx_t, ds, ns)) run_queue.append((prg.run, nthreads, None, False,)) inc >>= ninc diff --git a/test/test_algorithm.py b/test/test_algorithm.py index 56c649e9..cc32b1cb 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -848,6 +848,7 @@ def test_key_value_sorter(ctx_factory): np.float32, # np.float64 ]) +@pytest.mark.bitonic def test_bitonic_sort(ctx_factory, size, dtype): ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -856,8 +857,8 @@ def test_bitonic_sort(ctx_factory, size, dtype): from pyopencl.bitonic_sort import BitonicSort s = clrandom.rand(queue, (2, size, 3,), dtype, luxury=None, a=0, b=1.0) - sorter = BitonicSort(ctx, s.shape, s.dtype, axis=1) - sgs, evt = sorter(s) + sorter = BitonicSort(ctx, s.dtype) + sgs, evt = sorter(s, axis=1) assert np.array_equal(np.sort(s.get(), axis=1), sgs.get()) @@ -872,6 +873,7 @@ def test_bitonic_sort(ctx_factory, size, dtype): np.float32, # np.float64 ]) +@pytest.mark.bitonic def test_bitonic_argsort(ctx_factory, size, dtype): ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -882,9 +884,9 @@ def test_bitonic_argsort(ctx_factory, size, dtype): index = cl_array.arange(queue, 0, size, 1, dtype=np.int32) m = clrandom.rand(queue, (size,), np.float32, luxury=None, a=0, b=1.0) - sorterm = BitonicSort(ctx, m.shape, m.dtype, idx_dtype=index.dtype, axis=0) + sorterm = BitonicSort(ctx, m.dtype, idx_dtype=index.dtype) - ms, evt = sorterm(m, idx=index) + ms, evt = sorterm(m, idx=index, axis=0) assert np.array_equal(np.sort(m.get()), ms.get()) -- GitLab