From 0671b822fa9c5619da5f81d107d12e390c9ce0d2 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 13 Jul 2012 11:28:36 -0500
Subject: [PATCH] Support complex numbers in reductions, sum, dot.

---
 pyopencl/reduction.py | 25 +++++++++++++++++++++++--
 test/test_array.py    | 42 ++++++++++++++++++++++++++++--------------
 2 files changed, 51 insertions(+), 16 deletions(-)

diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py
index 3085587c..25acbc9f 100644
--- a/pyopencl/reduction.py
+++ b/pyopencl/reduction.py
@@ -44,17 +44,20 @@ import pyopencl._mymako as mako
 
 
 KERNEL = """
-
     #define GROUP_SIZE ${group_size}
     #define READ_AND_MAP(i) (${map_expr})
     #define REDUCE(a, b) (${reduce_expr})
 
     % if double_support:
         #pragma OPENCL EXTENSION cl_khr_fp64: enable
+        #define PYOPENCL_DEFINE_CDOUBLE
     % elif amd_double_support:
         #pragma OPENCL EXTENSION cl_amd_fp64: enable
+        #define PYOPENCL_DEFINE_CDOUBLE
     % endif
 
+    #include <pyopencl-complex.h>
+
     ${preamble}
 
     typedef ${out_type} out_type;
@@ -402,8 +405,26 @@ def get_dot_kernel(ctx, dtype_out, dtype_a=None, dtype_b=None):
     if dtype_a is None:
         dtype_a = dtype_out
 
+    a_is_complex = dtype_a.kind == "c"
+    b_is_complex = dtype_b.kind == "c"
+    out_is_complex = dtype_out.kind == "c"
+
+    if out_is_complex:
+        a = "a[i]"
+        b = "b[i]"
+        from pyopencl.elementwise import complex_dtype_to_name
+        if a_is_complex and dtype_a != dtype_out:
+            a = "%s_cast(%s)" % (complex_dtype_to_name(dtype_out), a)
+        if b_is_complex and dtype_b != dtype_out:
+            b = "%s_cast(%s)" % (complex_dtype_to_name(dtype_out), b)
+
+        map_expr = "%s_mul(%s, %s)" % (
+                complex_dtype_to_name(dtype_out), a, b)
+    else:
+        map_expr = "a[i]*b[i]"
+
     return ReductionKernel(ctx, dtype_out, neutral="0",
-            reduce_expr="a+b", map_expr="a[i]*b[i]",
+            reduce_expr="a+b", map_expr=map_expr,
             arguments=
             "__global const %(tp_a)s *a, "
             "__global const %(tp_b)s *b" % {
diff --git a/test/test_array.py b/test/test_array.py
index 70b128f1..8e5a56ac 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -490,20 +490,34 @@ def test_reverse(ctx_factory):
     assert (a[::-1] == a_gpu.get()).all()
 
 
+def general_clrand(queue, shape, dtype):
+    from pyopencl.clrandom import rand as clrand
+
+    dtype = np.dtype(dtype)
+    if dtype.kind == "c":
+        real_dtype = dtype.type(0).real.dtype
+        return clrand(queue, shape, real_dtype) + 1j*clrand(queue, shape, real_dtype)
+    else:
+        return clrand(queue, shape, dtype)
+
+
+
+
 @pytools.test.mark_test.opencl
 def test_sum(ctx_factory):
     context = ctx_factory()
     queue = cl.CommandQueue(context)
 
-    from pyopencl.clrandom import rand as clrand
+    n = 200000
+    for dtype in [np.float32, np.complex64]:
+        a_gpu = general_clrand(queue, (n,), dtype)
 
-    a_gpu = clrand(queue, (200000,), np.float32)
-    a = a_gpu.get()
+        a = a_gpu.get()
 
-    sum_a = np.sum(a)
-    sum_a_gpu = cl_array.sum(a_gpu).get()
+        sum_a = np.sum(a)
+        sum_a_gpu = cl_array.sum(a_gpu).get()
 
-    assert abs(sum_a_gpu - sum_a) / abs(sum_a) < 1e-4
+        assert abs(sum_a_gpu - sum_a) / abs(sum_a) < 1e-4
 
 
 @pytools.test.mark_test.opencl
@@ -574,17 +588,17 @@ def test_dot(ctx_factory):
     context = ctx_factory()
     queue = cl.CommandQueue(context)
 
-    from pyopencl.clrandom import rand as clrand
-    a_gpu = clrand(queue, (200000,), np.float32)
-    a = a_gpu.get()
-    b_gpu = clrand(queue, (200000,), np.float32)
-    b = b_gpu.get()
+    for dtype in [np.float32, np.complex64]:
+        a_gpu = general_clrand(queue, (200000,), dtype)
+        a = a_gpu.get()
+        b_gpu = general_clrand(queue, (200000,), dtype)
+        b = b_gpu.get()
 
-    dot_ab = np.dot(a, b)
+        dot_ab = np.dot(a, b)
 
-    dot_ab_gpu = cl_array.dot(a_gpu, b_gpu).get()
+        dot_ab_gpu = cl_array.dot(a_gpu, b_gpu).get()
 
-    assert abs(dot_ab_gpu - dot_ab) / abs(dot_ab) < 1e-4
+        assert abs(dot_ab_gpu - dot_ab) / abs(dot_ab) < 1e-4
 
 
 if False:
-- 
GitLab