From bcbb113adffcf89dcc75f4a3e32ae4f785c5366c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 15 Jul 2015 15:44:41 -0500 Subject: [PATCH] First steps to incorporate bitonic --- pyopencl/bitonic_sort.py | 154 +++++++++ pyopencl/bitonic_sort_templates.py | 516 +++++++++++++++++++++++++++++ test/test_algorithm.py | 41 ++- 3 files changed, 706 insertions(+), 5 deletions(-) create mode 100644 pyopencl/bitonic_sort.py create mode 100644 pyopencl/bitonic_sort_templates.py diff --git a/pyopencl/bitonic_sort.py b/pyopencl/bitonic_sort.py new file mode 100644 index 00000000..f8431d6e --- /dev/null +++ b/pyopencl/bitonic_sort.py @@ -0,0 +1,154 @@ +from __future__ import division, with_statement, absolute_import, print_function + +__copyright__ = """ +Copyright (c) 2011, Eric Bainville +Copyright (c) 2015, Ilya Efimoff +All rights reserved. +""" + +# based on code at + +__license__ = """ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors +may be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import pyopencl as cl +from pyopencl.tools import dtype_to_ctype +from mako.template import Template +from operator import mul +from functools import reduce +from pytools import memoize_method + + +class BitonicSort(object): + def __init__(self, context, shape, key_dtype, idx_dtype=None, axis=0): + import pyopencl.bitonic_sort_templates as tmpl + + self.cached_defs = {} + self.kernels_srcs = { + 'B2': tmpl.ParallelBitonic_B2, + 'B4': tmpl.ParallelBitonic_B4, + 'B8': tmpl.ParallelBitonic_B8, + 'B16': tmpl.ParallelBitonic_B16, + 'C4': tmpl.ParallelBitonic_C4, + 'BL': tmpl.ParallelBitonic_Local, + 'BLO': tmpl.ParallelBitonic_Local_Optim, + 'PML': tmpl.ParallelMerge_Local + } + + self.dtype = dtype_to_ctype(key_dtype) + self.context = context + self.axis = axis + if idx_dtype is None: + self.argsort = 0 + self.idx_t = 'uint' # Dummy + else: + self.argsort = 1 + self.idx_t = dtype_to_ctype(idx_dtype) + self.defstpl = Template(tmpl.defines) + self.rq = self.sort_b_prepare_wl(shape, self.axis) + + def __call__(self, _arr, idx=None, mkcpy=True): + arr = _arr.copy() if mkcpy else _arr + rq = self.rq + p, nt, wg, aux = rq[0] + if self.argsort and not type(idx)==type(None): + if aux: + p.run(arr.queue, (nt,), wg, arr.data, idx.data, cl.LocalMemory(wg[0]*4*arr.dtype.itemsize),\ + cl.LocalMemory(wg[0]*4*idx.dtype.itemsize)) + for p, nt, wg,_ in rq[1:]: + p.run(arr.queue, (nt,), wg, arr.data, idx.data) + elif self.argsort==0: + if aux: + p.run(arr.queue, (nt,), wg, arr.data, cl.LocalMemory(wg[0]*4*arr.dtype.itemsize)) + for p, nt, wg,_ in rq[1:]: + p.run(arr.queue, (nt,), wg, arr.data) + else: + raise ValueError("Array of indexes required for this sorter. If argsort is not needed,\ + recreate sorter witout index datatype provided.") + return arr + + @memoize_method + def get_program(self, letter, params): + if params in self.cached_defs.keys(): + defs = self.cached_defs[params] + else: + defs = self.defstpl.render( + NS="\\", argsort=self.argsort, inc=params[0], dir=params[1], + dtype=params[2], idxtype=params[3], + dsize=params[4], nsize=params[5]) + + self.cached_defs[params] = defs + kid = Template(self.kernels_srcs[letter]).render(argsort=self.argsort) + prg = cl.Program(self.context, defs + kid).build() + return prg + + def sort_b_prepare_wl(self, shape, axis): + run_queue = [] + ds = int(shape[axis]) + size = reduce(mul, shape) + ndim = len(shape) + + ns = reduce(mul, shape[(axis+1):]) if axis < ndim-1 else 1 + + ds = int(shape[axis]) + allowb4 = True + allowb8 = True + allowb16 = True + + wg = min(ds, self.context.devices[0].max_work_group_size) + length = wg >> 1 + prg = self.get_program('BLO', (1, 1, self.dtype, self.idx_t, ds, ns)) + run_queue.append((prg, size, (wg,), True)) + + while length < ds: + inc = length + while inc > 0: + ninc = 0 + direction = length << 1 + if allowb16 and inc >= 8 and ninc == 0: + letter = 'B16' + ninc = 4 + elif allowb8 and inc >= 4 and ninc == 0: + letter = 'B8' + ninc = 3 + elif allowb4 and inc >= 2 and ninc == 0: + letter = 'B4' + ninc = 2 + elif inc >= 0: + letter = 'B2' + ninc = 1 + + nthreads = size >> ninc + + prg = self.get_program(letter, + (inc, direction, self.dtype, self.idx_t, ds, ns)) + run_queue.append((prg, nthreads, None, False,)) + inc >>= ninc + + length <<= 1 + + return run_queue diff --git a/pyopencl/bitonic_sort_templates.py b/pyopencl/bitonic_sort_templates.py new file mode 100644 index 00000000..eaaa2a7a --- /dev/null +++ b/pyopencl/bitonic_sort_templates.py @@ -0,0 +1,516 @@ +__copyright__ = """ +Copyright (c) 2011, Eric Bainville +Copyright (c) 2015, Ilya Efimoff +All rights reserved. +""" +__license__ = """ +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +defines = """ +typedef ${dtype} data_t; +typedef ${idxtype} idx_t; +typedef ${idxtype}2 idx_t2; +#if CONFIG_USE_VALUE +#define getKey(a) ((a).x) +#define getValue(a) ((a).y) +#define makeData(k,v) ((${dtype}2)((k),(v))) +#else +#define getKey(a) (a) +#define getValue(a) (0) +#define makeData(k,v) (k) +#endif + +#ifndef BLOCK_FACTOR +#define BLOCK_FACTOR 1 +#endif + +#define inc ${inc} +#define hinc ${inc>>1} //Half inc +#define qinc ${inc>>2} //Quarter inc +#define einc ${inc>>3} //Eighth of inc +#define dir ${dir} + +% if argsort: +#define ORDER(a,b,ay,by) { bool swap = reverse ^ (getKey(a)<getKey(b));${NS} + data_t auxa = a; data_t auxb = b;${NS} + idx_t auya = ay; idx_t auyb = by;${NS} + a = (swap)?auxb:auxa; b = (swap)?auxa:auxb;${NS} + ay = (swap)?auyb:auya; by = (swap)?auya:auyb;} +#define ORDERV(x,y,a,b) { bool swap = reverse ^ (getKey(x[a])<getKey(x[b]));${NS} + data_t auxa = x[a]; data_t auxb = x[b];${NS} + idx_t auya = y[a]; idx_t auyb = y[b];${NS} + x[a] = (swap)?auxb:auxa; x[b] = (swap)?auxa:auxb;${NS} + y[a] = (swap)?auyb:auya; y[b] = (swap)?auya:auyb;} +#define B2V(x,y,a) { ORDERV(x,y,a,a+1) } +#define B4V(x,y,a) { for (int i4=0;i4<2;i4++) { ORDERV(x,y,a+i4,a+i4+2) } B2V(x,y,a) B2V(x,y,a+2) } +#define B8V(x,y,a) { for (int i8=0;i8<4;i8++) { ORDERV(x,y,a+i8,a+i8+4) } B4V(x,y,a) B4V(x,y,a+4) } +#define B16V(x,y,a) { for (int i16=0;i16<8;i16++) { ORDERV(x,y,a+i16,a+i16+8) } B8V(x,y,a) B8V(x,y,a+8) } +% else: +#define ORDER(a,b) { bool swap = reverse ^ (getKey(a)<getKey(b)); data_t auxa = a; data_t auxb = b; a = (swap)?auxb:auxa; b = (swap)?auxa:auxb; } +#define ORDERV(x,a,b) { bool swap = reverse ^ (getKey(x[a])<getKey(x[b]));${NS} + data_t auxa = x[a]; data_t auxb = x[b];${NS} + x[a] = (swap)?auxb:auxa; x[b] = (swap)?auxa:auxb; } +#define B2V(x,a) { ORDERV(x,a,a+1) } +#define B4V(x,a) { for (int i4=0;i4<2;i4++) { ORDERV(x,a+i4,a+i4+2) } B2V(x,a) B2V(x,a+2) } +#define B8V(x,a) { for (int i8=0;i8<4;i8++) { ORDERV(x,a+i8,a+i8+4) } B4V(x,a) B4V(x,a+4) } +#define B16V(x,a) { for (int i16=0;i16<8;i16++) { ORDERV(x,a+i16,a+i16+8) } B8V(x,a) B8V(x,a+8) } +% endif +#define nsize ${nsize} //Total next dimensions sizes sum. (Block size) +#define dsize ${dsize} //Dimension size +""" + +ParallelBitonic_B2 = """ +// N/2 threads +//ParallelBitonic_B2 +__kernel void run(__global data_t * data\\ +% if argsort: +, __global idx_t * index) +% else: +) +% endif +{ + int t = get_global_id(0) % (dsize>>1); // thread index + int gt = get_global_id(0) / (dsize>>1); + int low = t & (inc - 1); // low order bits (below INC) + int i = (t<<1) - low; // insert 0 at position INC + int gi = i/dsize; // block index + bool reverse = ((dir & i) == 0);// ^ (gi%2); // asc/desc order + + int offset = (gt/nsize)*nsize*dsize+(gt%nsize); + data += i*nsize + offset; // translate to first value +% if argsort: + index += i*nsize + offset; // translate to first value +% endif + + // Load data + data_t x0 = data[ 0]; + data_t x1 = data[inc*nsize]; +% if argsort: + // Load index + idx_t i0 = index[ 0]; + idx_t i1 = index[inc*nsize]; +% endif + + // Sort +% if argsort: + ORDER(x0,x1,i0,i1) +% else: + ORDER(x0,x1) +% endif + + // Store data + data[0 ] = x0; + data[inc*nsize] = x1; +% if argsort: + // Store index + index[ 0] = i0; + index[inc*nsize] = i1; +% endif +} +""" + +ParallelBitonic_B4 = """ +// N/4 threads +//ParallelBitonic_B4 +__kernel void run(__global data_t * data\\ +% if argsort: +, __global idx_t * index) +% else: +) +% endif +{ + int t = get_global_id(0) % (dsize>>2); // thread index + int gt = get_global_id(0) / (dsize>>2); + int low = t & (hinc - 1); // low order bits (below INC) + int i = ((t - low) << 2) + low; // insert 00 at position INC + bool reverse = ((dir & i) == 0); // asc/desc order + int offset = (gt/nsize)*nsize*dsize+(gt%nsize); + data += i*nsize + offset; // translate to first value +% if argsort: + index += i*nsize + offset; // translate to first value +% endif + + // Load data + data_t x0 = data[ 0]; + data_t x1 = data[ hinc*nsize]; + data_t x2 = data[2*hinc*nsize]; + data_t x3 = data[3*hinc*nsize]; +% if argsort: + // Load index + idx_t i0 = index[ 0]; + idx_t i1 = index[ hinc*nsize]; + idx_t i2 = index[2*hinc*nsize]; + idx_t i3 = index[3*hinc*nsize]; +% endif + + // Sort +% if argsort: + ORDER(x0,x2,i0,i2) + ORDER(x1,x3,i1,i3) + ORDER(x0,x1,i0,i1) + ORDER(x2,x3,i2,i3) +% else: + ORDER(x0,x2) + ORDER(x1,x3) + ORDER(x0,x1) + ORDER(x2,x3) +% endif + + // Store data + data[ 0] = x0; + data[ hinc*nsize] = x1; + data[2*hinc*nsize] = x2; + data[3*hinc*nsize] = x3; +% if argsort: + // Store index + index[ 0] = i0; + index[ hinc*nsize] = i1; + index[2*hinc*nsize] = i2; + index[3*hinc*nsize] = i3; +% endif +} +""" + +ParallelBitonic_B8 = """ +// N/8 threads +//ParallelBitonic_B8 +__kernel void run(__global data_t * data\\ +% if argsort: +, __global idx_t * index) +% else: +) +% endif +{ + int t = get_global_id(0) % (dsize>>3); // thread index + int gt = get_global_id(0) / (dsize>>3); + int low = t & (qinc - 1); // low order bits (below INC) + int i = ((t - low) << 3) + low; // insert 000 at position INC + bool reverse = ((dir & i) == 0); // asc/desc order + int offset = (gt/nsize)*nsize*dsize+(gt%nsize); + + data += i*nsize + offset; // translate to first value +% if argsort: + index += i*nsize + offset; // translate to first value +% endif + + // Load + data_t x[8]; +% if argsort: + idx_t y[8]; +% endif + for (int k=0;k<8;k++) x[k] = data[k*qinc*nsize]; +% if argsort: + for (int k=0;k<8;k++) y[k] = index[k*qinc*nsize]; +% endif + + // Sort +% if argsort: + B8V(x,y,0) +% else: + B8V(x,0) +% endif + + // Store + for (int k=0;k<8;k++) data[k*qinc*nsize] = x[k]; +% if argsort: + for (int k=0;k<8;k++) index[k*qinc*nsize] = y[k]; +% endif +} +""" + +ParallelBitonic_B16 = """ +// N/16 threads +//ParallelBitonic_B16 +__kernel void run(__global data_t * data\\ +% if argsort: +, __global idx_t * index) +% else: +) +% endif +{ + int t = get_global_id(0) % (dsize>>4); // thread index + int gt = get_global_id(0) / (dsize>>4); + int low = t & (einc - 1); // low order bits (below INC) + int i = ((t - low) << 4) + low; // insert 0000 at position INC + bool reverse = ((dir & i) == 0); // asc/desc order + int offset = (gt/nsize)*nsize*dsize+(gt%nsize); + + data += i*nsize + offset; // translate to first value +% if argsort: + index += i*nsize + offset; // translate to first value +% endif + + // Load + data_t x[16]; +% if argsort: + idx_t y[16]; +% endif + for (int k=0;k<16;k++) x[k] = data[k*einc*nsize]; +% if argsort: + for (int k=0;k<16;k++) y[k] = index[k*einc*nsize]; +% endif + + // Sort +% if argsort: + B16V(x,y,0) +% else: + B16V(x,0) +% endif + + // Store + for (int k=0;k<16;k++) data[k*einc*nsize] = x[k]; +% if argsort: + for (int k=0;k<16;k++) index[k*einc*nsize] = y[k]; +% endif +} +""" + +ParallelBitonic_C4 = """ +//ParallelBitonic_C4 +__kernel void run\\ +% if argsort: +(__global data_t * data, __global idx_t * index, __local data_t * aux, __local idx_t * auy) +% else: +(__global data_t * data, __local data_t * aux) +% endif +{ + int t = get_global_id(0); // thread index + int wgBits = 4*get_local_size(0) - 1; // bit mask to get index in local memory AUX (size is 4*WG) + int linc,low,i; + bool reverse; + data_t x[4]; +% if argsort: + idx_t y[4]; +% endif + + // First iteration, global input, local output + linc = hinc; + low = t & (linc - 1); // low order bits (below INC) + i = ((t - low) << 2) + low; // insert 00 at position INC + reverse = ((dir & i) == 0); // asc/desc order + for (int k=0;k<4;k++) x[k] = data[i+k*linc]; +% if argsort: + for (int k=0;k<4;k++) y[k] = index[i+k*linc]; + B4V(x,y,0); + for (int k=0;k<4;k++) auy[(i+k*linc) & wgBits] = y[k]; +% else: + B4V(x,0); +% endif + for (int k=0;k<4;k++) aux[(i+k*linc) & wgBits] = x[k]; + barrier(CLK_LOCAL_MEM_FENCE); + + // Internal iterations, local input and output + for ( ;linc>1;linc>>=2) + { + low = t & (linc - 1); // low order bits (below INC) + i = ((t - low) << 2) + low; // insert 00 at position INC + reverse = ((dir & i) == 0); // asc/desc order + for (int k=0;k<4;k++) x[k] = aux[(i+k*linc) & wgBits]; +% if argsort: + for (int k=0;k<4;k++) y[k] = auy[(i+k*linc) & wgBits]; + B4V(x,y,0); + barrier(CLK_LOCAL_MEM_FENCE); + for (int k=0;k<4;k++) auy[(i+k*linc) & wgBits] = y[k]; +% else: + B4V(x,0); + barrier(CLK_LOCAL_MEM_FENCE); +% endif + for (int k=0;k<4;k++) aux[(i+k*linc) & wgBits] = x[k]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + // Final iteration, local input, global output, INC=1 + i = t << 2; + reverse = ((dir & i) == 0); // asc/desc order + for (int k=0;k<4;k++) x[k] = aux[(i+k) & wgBits]; +% if argsort: + for (int k=0;k<4;k++) y[k] = auy[(i+k) & wgBits]; + B4V(x,y,0); + for (int k=0;k<4;k++) index[i+k] = y[k]; +% else: + B4V(x,0); +% endif + for (int k=0;k<4;k++) data[i+k] = x[k]; +} +""" + + +ParallelMerge_Local = """ +// N threads, WG is workgroup size. Sort WG input blocks in each workgroup. +__kernel void run(__global const data_t * in,__global data_t * out,__local data_t * aux) +{ + int i = get_local_id(0); // index in workgroup + int wg = get_local_size(0); // workgroup size = block size, power of 2 + + // Move IN, OUT to block start + int offset = get_group_id(0) * wg; + in += offset; out += offset; + + // Load block in AUX[WG] + aux[i] = in[i]; + barrier(CLK_LOCAL_MEM_FENCE); // make sure AUX is entirely up to date + + // Now we will merge sub-sequences of length 1,2,...,WG/2 + for (int length=1;length<wg;length<<=1) + { + data_t iData = aux[i]; + data_t iKey = getKey(iData); + int ii = i & (length-1); // index in our sequence in 0..length-1 + int sibling = (i - ii) ^ length; // beginning of the sibling sequence + int pos = 0; + for (int pinc=length;pinc>0;pinc>>=1) // increment for dichotomic search + { + int j = sibling+pos+pinc-1; + data_t jKey = getKey(aux[j]); + bool smaller = (jKey < iKey) || ( jKey == iKey && j < i ); + pos += (smaller)?pinc:0; + pos = min(pos,length); + } + int bits = 2*length-1; // mask for destination + int dest = ((ii + pos) & bits) | (i & ~bits); // destination index in merged sequence + barrier(CLK_LOCAL_MEM_FENCE); + aux[dest] = iData; + barrier(CLK_LOCAL_MEM_FENCE); + } + + // Write output + out[i] = aux[i]; +} +""" + +ParallelBitonic_Local = """ +// N threads, WG is workgroup size. Sort WG input blocks in each workgroup. +__kernel void run(__global const data_t * in,__global data_t * out,__local data_t * aux) +{ + int i = get_local_id(0); // index in workgroup + int wg = get_local_size(0); // workgroup size = block size, power of 2 + + // Move IN, OUT to block start + int offset = get_group_id(0) * wg; + in += offset; out += offset; + + // Load block in AUX[WG] + aux[i] = in[i]; + barrier(CLK_LOCAL_MEM_FENCE); // make sure AUX is entirely up to date + + // Loop on sorted sequence length + for (int length=1;length<wg;length<<=1) + { + bool direction = ((i & (length<<1)) != 0); // direction of sort: 0=asc, 1=desc + // Loop on comparison distance (between keys) + for (int pinc=length;pinc>0;pinc>>=1) + { + int j = i + pinc; // sibling to compare + data_t iData = aux[i]; + uint iKey = getKey(iData); + data_t jData = aux[j]; + uint jKey = getKey(jData); + bool smaller = (jKey < iKey) || ( jKey == iKey && j < i ); + bool swap = smaller ^ (j < i) ^ direction; + barrier(CLK_LOCAL_MEM_FENCE); + aux[i] = (swap)?jData:iData; + barrier(CLK_LOCAL_MEM_FENCE); + } + } + + // Write output + out[i] = aux[i]; +} +""" + +ParallelBitonic_A = """ +__kernel void ParallelBitonic_A(__global const data_t * in) +{ + int i = get_global_id(0); // thread index + int j = i ^ inc; // sibling to compare + + // Load values at I and J + data_t iData = in[i]; + uint iKey = getKey(iData); + data_t jData = in[j]; + uint jKey = getKey(jData); + + // Compare + bool smaller = (jKey < iKey) || ( jKey == iKey && j < i ); + bool swap = smaller ^ (j < i) ^ ((dir & i) != 0); + + // Store + in[i] = (swap)?jData:iData; +} +""" + +ParallelBitonic_Local_Optim = """ +__kernel void run\\ +% if argsort: +(__global data_t * data, __global idx_t * index, __local data_t * aux, __local idx_t * auy) +% else: +(__global data_t * data, __local data_t * aux) +% endif +{ + int t = get_global_id(0) % dsize; // thread index + int gt = get_global_id(0) / dsize; + int offset = (gt/nsize)*nsize*dsize+(gt%nsize); + + int i = get_local_id(0); // index in workgroup + int wg = get_local_size(0); // workgroup size = block size, power of 2 + + // Move IN, OUT to block start + //int offset = get_group_id(0) * wg; + data += offset; + // Load block in AUX[WG] + data_t iData = data[t*nsize]; + aux[i] = iData; +% if argsort: + index += offset; + // Load block in AUY[WG] + idx_t iidx = index[t*nsize]; + auy[i] = iidx; +% endif + barrier(CLK_LOCAL_MEM_FENCE); // make sure AUX is entirely up to date + + // Loop on sorted sequence length + for (int pwg=1;pwg<=wg;pwg<<=1){ + int loffset = pwg*(i/pwg); + int ii = i%pwg; + for (int length=1;length<pwg;length<<=1){ + bool direction = ii & (length<<1); // direction of sort: 0=asc, 1=desc + // Loop on comparison distance (between keys) + for (int pinc=length;pinc>0;pinc>>=1){ + int j = ii ^ pinc; // sibling to compare + data_t jData = aux[loffset+j]; +% if argsort: + idx_t jidx = auy[loffset+j]; +% endif + data_t iKey = getKey(iData); + data_t jKey = getKey(jData); + bool smaller = (jKey < iKey) || ( jKey == iKey && j < ii ); + bool swap = smaller ^ (ii>j) ^ direction; + iData = (swap)?jData:iData; // update iData +% if argsort: + iidx = (swap)?jidx:iidx; // update iidx +% endif + barrier(CLK_LOCAL_MEM_FENCE); + aux[loffset+ii] = iData; +% if argsort: + auy[loffset+ii] = iidx; +% endif + barrier(CLK_LOCAL_MEM_FENCE); + } + } + } + + // Write output + data[t*nsize] = iData; +% if argsort: + index[t*nsize] = iidx; +% endif +} +""" \ No newline at end of file diff --git a/test/test_algorithm.py b/test/test_algorithm.py index b55c850e..3f135a40 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -1,10 +1,6 @@ #! /usr/bin/env python -from __future__ import division, with_statement -from __future__ import absolute_import -from __future__ import print_function -from six.moves import range -from six.moves import zip +from __future__ import division, with_statement, absolute_import, print_function __copyright__ = "Copyright (C) 2013 Andreas Kloeckner" @@ -28,6 +24,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from six.moves import range, zip import numpy as np import numpy.linalg as la import sys @@ -839,6 +836,40 @@ def test_key_value_sorter(ctx_factory): # }}} +# {{{ bitonic sort + +def test_bitonic_sort(ctx_factory): + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + import pyopencl.clrandom as clrandom + from pyopencl.bitonic_sort import BitonicSort + + s = clrandom.rand(queue, (4, 512, 5,), np.float32, luxury=None, a=0, b=1.0) + sorter = BitonicSort(ctx, s.shape, s.dtype, axis=1) + sgs = sorter(s) + assert np.array_equal(np.sort(s.get(), axis=1), sgs.get()) + + size = 2**18 + + index = cl_array.arange(queue, 0, size, 1, dtype=np.int32) + m = clrandom.rand(queue, (size,), np.float32, luxury=None, a=0, b=1.0) + + sorterm = BitonicSort(ctx, m.shape, m.dtype, idx_dtype=index.dtype, axis=0) + + ms = sorterm(m, idx=index) + + assert np.array_equal(np.sort(m.get()), ms.get()) + + # may be False because of identical values in array + # assert np.array_equal(np.argsort(m.get()), index.get()) + + # Check values by indices + assert np.array_equal(m.get()[np.argsort(m.get())], m.get()[index.get()]) + +# }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab