Skip to content
Snippets Groups Projects
Commit 092db40d authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add struct reduce test.

parent 44602b21
No related branches found
No related tags found
No related merge requests found
......@@ -55,11 +55,10 @@ KERNEL = """
#pragma OPENCL EXTENSION cl_amd_fp64: enable
% endif
${preamble}
typedef ${out_type} out_type;
${preamble}
__kernel void ${name}(
__global out_type *out, ${arguments},
unsigned int seq_count, unsigned int n)
......
......@@ -656,6 +656,66 @@ def test_view(ctx_factory):
view = a_dev.view(np.int16)
assert view.shape == (8, 32) and view.dtype == np.int16
@pytools.test.mark_test.opencl
def test_struct_reduce(ctx_factory):
context = ctx_factory()
queue = cl.CommandQueue(context)
preamble = """//CL//
struct minmax_collector
{
float cur_min;
float cur_max;
};
typedef struct minmax_collector minmax_collector;
minmax_collector mmc_neutral()
{
// FIXME: needs infinity literal in real use, ok here
minmax_collector result = {10000, -10000};
return result;
}
minmax_collector mmc_from_scalar(float x)
{
minmax_collector result = {x, x};
return result;
}
minmax_collector agg_mmc(minmax_collector a, minmax_collector b)
{
minmax_collector result = {
fmin(a.cur_min, b.cur_min),
fmax(a.cur_max, b.cur_max),
};
return result;
}
"""
mmc_dtype = np.dtype([("cur_min", np.float32), ("cur_max", np.float32)])
from pyopencl.clrandom import rand as clrand
a_gpu = clrand(queue, (20000,), dtype=np.float32)
a = a_gpu.get()
from pyopencl.tools import register_dtype
register_dtype(mmc_dtype, "minmax_collector")
from pyopencl.reduction import ReductionKernel
red = ReductionKernel(context, mmc_dtype,
neutral="mmc_neutral()",
reduce_expr="agg_mmc(a, b)", map_expr="mmc_from_scalar(x[i])",
arguments="__global float *x", preamble=preamble)
minmax = red(a_gpu).get()
#print minmax["cur_min"], minmax["cur_max"]
#print np.min(a), np.max(a)
assert minmax["cur_min"] == np.min(a)
assert minmax["cur_max"] == np.max(a)
......@@ -670,3 +730,5 @@ if __name__ == "__main__":
else:
from py.test.cmdline import main
main([__file__])
# vim: filetype=pyopencl
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment