From 092db40d3fc6c0699645370eacea64228baa75ce Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 18 Nov 2011 04:17:13 -0500
Subject: [PATCH] Add struct reduce test.

---
 pyopencl/reduction.py |  3 +--
 test/test_array.py    | 62 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py
index 6949320b..b4b1d5e4 100644
--- a/pyopencl/reduction.py
+++ b/pyopencl/reduction.py
@@ -55,11 +55,10 @@ KERNEL = """
         #pragma OPENCL EXTENSION cl_amd_fp64: enable
     % endif
 
+    ${preamble}
 
     typedef ${out_type} out_type;
 
-    ${preamble}
-
     __kernel void ${name}(
       __global out_type *out, ${arguments},
       unsigned int seq_count, unsigned int n)
diff --git a/test/test_array.py b/test/test_array.py
index 8a6d4213..dc67aa20 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -656,6 +656,66 @@ def test_view(ctx_factory):
     view = a_dev.view(np.int16)
     assert view.shape == (8, 32) and view.dtype == np.int16
 
+@pytools.test.mark_test.opencl
+def test_struct_reduce(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    preamble = """//CL//
+    struct minmax_collector
+    {
+        float cur_min;
+        float cur_max;
+    };
+
+    typedef struct minmax_collector minmax_collector;
+
+    minmax_collector mmc_neutral()
+    {
+        // FIXME: needs infinity literal in real use, ok here
+        minmax_collector result = {10000, -10000};
+        return result;
+    }
+
+    minmax_collector mmc_from_scalar(float x)
+    {
+        minmax_collector result = {x, x};
+        return result;
+    }
+
+    minmax_collector agg_mmc(minmax_collector a, minmax_collector b)
+    {
+        minmax_collector result = {
+            fmin(a.cur_min, b.cur_min),
+            fmax(a.cur_max, b.cur_max),
+            };
+        return result;
+    }
+
+    """
+
+    mmc_dtype = np.dtype([("cur_min", np.float32), ("cur_max", np.float32)])
+
+    from pyopencl.clrandom import rand as clrand
+    a_gpu = clrand(queue, (20000,), dtype=np.float32)
+    a = a_gpu.get()
+
+    from pyopencl.tools import register_dtype
+    register_dtype(mmc_dtype, "minmax_collector")
+
+    from pyopencl.reduction import ReductionKernel
+    red = ReductionKernel(context, mmc_dtype,
+            neutral="mmc_neutral()",
+            reduce_expr="agg_mmc(a, b)", map_expr="mmc_from_scalar(x[i])",
+            arguments="__global float *x", preamble=preamble)
+
+    minmax = red(a_gpu).get()
+    #print minmax["cur_min"], minmax["cur_max"]
+    #print np.min(a), np.max(a)
+
+    assert minmax["cur_min"] == np.min(a)
+    assert minmax["cur_max"] == np.max(a)
+
 
 
 
@@ -670,3 +730,5 @@ if __name__ == "__main__":
     else:
         from py.test.cmdline import main
         main([__file__])
+
+# vim: filetype=pyopencl
-- 
GitLab