From c11e5576cbc911e24e73bad5093c8adf73a2c6a4 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992@gmail.com>
Date: Sun, 18 May 2014 11:58:22 -0400
Subject: [PATCH] fix compile error for binary float function with two
 different input types, add tests for atan2 atan2pi

---
 pyopencl/clmath.py      |  8 ++++----
 pyopencl/elementwise.py | 25 +++++++++++++++++++++++--
 test/test_clmath.py     | 34 ++++++++++++++++++++++++++++++++++
 3 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/pyopencl/clmath.py b/pyopencl/clmath.py
index 12898529..1b41ce67 100644
--- a/pyopencl/clmath.py
+++ b/pyopencl/clmath.py
@@ -57,14 +57,14 @@ asinpi = _make_unary_array_func("asinpi")
 
 @cl_array.elwise_kernel_runner
 def _atan2(result, arg1, arg2):
-    return elementwise.get_binary_func_kernel(result.context, "atan2",
-            arg1.dtype, arg2.dtype, result.dtype)
+    return elementwise.get_float_binary_func_kernel(
+        result.context, "atan2", arg1.dtype, arg2.dtype, result.dtype)
 
 
 @cl_array.elwise_kernel_runner
 def _atan2pi(result, arg1, arg2):
-    return elementwise.get_binary_func_kernel(result.context, "atan2pi",
-            arg1.dtype, arg2.dtype, result.dtype)
+    return elementwise.get_float_binary_func_kernel(
+        result.context, "atan2pi", arg1.dtype, arg2.dtype, result.dtype)
 
 
 atan = _make_unary_array_func("atan")
diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 8d583681..03c37024 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -837,11 +837,32 @@ def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype,
         preamble=preamble)
 
 
+@context_dependent_memoize
+def get_float_binary_func_kernel(context, func_name, x_dtype, y_dtype,
+                                 out_dtype, preamble="", name=None):
+    if (np.array(0, x_dtype) * np.array(0, y_dtype)).itemsize > 4:
+        arg_type = 'double'
+        preamble = """
+        #pragma OPENCL EXTENSION cl_khr_fp64: enable
+        #define PYOPENCL_DEFINE_CDOUBLE
+        """ + preamble
+    else:
+        arg_type = 'float'
+    return get_elwise_kernel(context, [
+        VectorArg(out_dtype, "z", with_offset=True),
+        VectorArg(x_dtype, "x", with_offset=True),
+        VectorArg(y_dtype, "y", with_offset=True),
+        ],
+        "z[i] = %s((%s)x[i], (%s)y[i])" % (func_name, arg_type, arg_type),
+        name="%s_kernel" % func_name if name is None else name,
+        preamble=preamble)
+
+
 @context_dependent_memoize
 def get_fmod_kernel(context, out_dtype=np.float32, arg_dtype=np.float32,
                     mod_dtype=np.float32):
-    return get_binary_func_kernel(context, 'fmod', arg_dtype,
-                                  mod_dtype, out_dtype)
+    return get_float_binary_func_kernel(context, 'fmod', arg_dtype,
+                                        mod_dtype, out_dtype)
 
 
 @context_dependent_memoize
diff --git a/test/test_clmath.py b/test/test_clmath.py
index 3091e942..6ebbe46b 100644
--- a/test/test_clmath.py
+++ b/test/test_clmath.py
@@ -127,6 +127,40 @@ if have_cl():
     test_tanh = make_unary_function_test("tanh", (-3, 3), 2e-6, use_complex=True)
 
 
+def test_atan2(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    for s in sizes:
+        a = (cl_array.arange(queue, s, dtype=np.float32) - s / 2) / 100
+        a2 = (s / 2 - 1 - cl_array.arange(queue, s, dtype=np.float32)) / 100
+        b = clmath.atan2(a, a2)
+
+        a = a.get()
+        a2 = a2.get()
+        b = b.get()
+
+        for i in range(s):
+            assert abs(math.atan2(a[i], a2[i]) - b[i]) < 1e-6
+
+
+def test_atan2pi(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    for s in sizes:
+        a = (cl_array.arange(queue, s, dtype=np.float32) - s / 2) / 100
+        a2 = (s / 2 - 1 - cl_array.arange(queue, s, dtype=np.float32)) / 100
+        b = clmath.atan2pi(a, a2)
+
+        a = a.get()
+        a2 = a2.get()
+        b = b.get()
+
+        for i in range(s):
+            assert abs(math.atan2(a[i], a2[i]) / math.pi - b[i]) < 1e-6
+
+
 def test_fmod(ctx_factory):
     context = ctx_factory()
     queue = cl.CommandQueue(context)
-- 
GitLab