From 2fe7d9f6303d7a152d04a7e257926979d7034da5 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 8 Mar 2012 14:12:05 -0500
Subject: [PATCH] Clean up dtype-related test failures.

---
 pyopencl/array.py                 | 67 ++++++++++++++++++++-----------
 pyopencl/characterize/__init__.py |  2 +
 pyopencl/compyte                  |  2 +-
 pyopencl/elementwise.py           |  8 +++-
 test/test_array.py                |  3 +-
 test/test_clmath.py               |  2 +-
 6 files changed, 56 insertions(+), 28 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index e608d196..8d84bb4a 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -39,7 +39,16 @@ from pyopencl.compyte.array import (
         f_contiguous_strides as _f_contiguous_strides,
         c_contiguous_strides as _c_contiguous_strides,
         ArrayFlags as _ArrayFlags,
-        get_common_dtype as _get_common_dtype)
+        get_common_dtype as _get_common_dtype_base)
+from pyopencl.characterize import has_double_support
+
+
+def _get_common_dtype(obj1, obj2, queue):
+    """Find the common dtype of *obj1* and *obj2* on *queue*'s device."""
+    # Tell the base helper whether the device supports doubles (cl_khr_fp64).
+    return _get_common_dtype_base(obj1, obj2,
+            has_double_support(queue.device))
+
 
 
 # {{{ vector types
@@ -516,7 +525,8 @@ class Array(object):
     def mul_add(self, selffac, other, otherfac, queue=None):
         """Return `selffac * self + otherfac*other`.
         """
-        result = self._new_like_me(_get_common_dtype(self, other))
+        result = self._new_like_me(
+                _get_common_dtype(self, other, queue or self.queue))
         self._axpbyz(result, selffac, self, otherfac, other)
         return result
 
@@ -525,7 +535,7 @@ class Array(object):
 
         if isinstance(other, Array):
             # add another vector
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = self._new_like_me(_get_common_dtype(self, other, self.queue))
             self._axpbyz(result,
                     self.dtype.type(1), self,
                     other.dtype.type(1), other)
@@ -535,8 +545,9 @@ class Array(object):
             if other == 0:
                 return self
             else:
-                result = self._new_like_me(_get_common_dtype(self, other))
-                self._axpbz(result, self.dtype.type(1), self, other)
+                common_dtype = _get_common_dtype(self, other, self.queue)
+                result = self._new_like_me(common_dtype)
+                self._axpbz(result, self.dtype.type(1), self, common_dtype.type(other))
                 return result
 
     __radd__ = __add__
@@ -545,7 +556,7 @@ class Array(object):
         """Substract an array from an array or a scalar from an array."""
 
         if isinstance(other, Array):
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = self._new_like_me(_get_common_dtype(self, other, self.queue))
             self._axpbyz(result,
                     self.dtype.type(1), self,
                     other.dtype.type(-1), other)
@@ -555,7 +566,7 @@ class Array(object):
             if other == 0:
                 return self
             else:
-                result = self._new_like_me(_get_common_dtype(self, other))
+                result = self._new_like_me(_get_common_dtype(self, other, self.queue))
                 self._axpbz(result, self.dtype.type(1), self, -other)
                 return result
 
@@ -564,9 +575,10 @@ class Array(object):
 
            x = n - self
         """
+        common_dtype = _get_common_dtype(self, other, self.queue)
         # other must be a scalar
-        result = self._new_like_me(_get_common_dtype(self, other))
-        self._axpbz(result, self.dtype.type(-1), self, other)
+        result = self._new_like_me(common_dtype)
+        self._axpbz(result, self.dtype.type(-1), self, common_dtype.type(other))
         return result
 
     def __iadd__(self, other):
@@ -594,17 +606,19 @@ class Array(object):
 
     def __mul__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = self._new_like_me(_get_common_dtype(self, other, self.queue))
             self._elwise_multiply(result, self, other)
             return result
         else:
-            result = self._new_like_me(_get_common_dtype(self, other))
-            self._axpbz(result, other, self, self.dtype.type(0))
+            common_dtype = _get_common_dtype(self, other, self.queue)
+            result = self._new_like_me(common_dtype)
+            self._axpbz(result, common_dtype.type(other), self, self.dtype.type(0))
             return result
 
     def __rmul__(self, scalar):
-        result = self._new_like_me(_get_common_dtype(self, scalar))
-        self._axpbz(result, scalar, self, self.dtype.type(0))
+        common_dtype = _get_common_dtype(self, scalar, self.queue)
+        result = self._new_like_me(common_dtype)
+        self._axpbz(result, common_dtype.type(scalar), self, self.dtype.type(0))
         return result
 
     def __imul__(self, scalar):
@@ -617,16 +631,17 @@ class Array(object):
            x = self / n
         """
         if isinstance(other, Array):
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = self._new_like_me(_get_common_dtype(self, other, self.queue))
             self._div(result, self, other)
         else:
             if other == 1:
                 return self
             else:
                 # create a new array for the result
-                result = self._new_like_me(_get_common_dtype(self, other))
+                common_dtype = _get_common_dtype(self, other, self.queue)
+                result = self._new_like_me(common_dtype)
                 self._axpbz(result,
-                        1/other, self, self.dtype.type(0))
+                        common_dtype.type(1/other), self, self.dtype.type(0))
 
         return result
 
@@ -639,15 +654,16 @@ class Array(object):
         """
 
         if isinstance(other, Array):
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = self._new_like_me(_get_common_dtype(self, other, self.queue))
             other._div(result, self)
         else:
             if other == 1:
                 return self
             else:
                 # create a new array for the result
-                result = self._new_like_me(_get_common_dtype(self, other))
-                self._rdiv_scalar(result, self, other)
+                common_dtype = _get_common_dtype(self, other, self.queue)
+                result = self._new_like_me(common_dtype)
+                self._rdiv_scalar(result, self, common_dtype.type(other))
 
         return result
 
@@ -682,18 +698,19 @@ class Array(object):
         if isinstance(other, Array):
             assert self.shape == other.shape
 
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = self._new_like_me(_get_common_dtype(self, other, self.queue))
             self._pow_array(result, self, other)
         else:
-            result = self._new_like_me(_get_common_dtype(self, other))
+            result = self._new_like_me(_get_common_dtype(self, other, self.queue))
             self._pow_scalar(result, self, other)
 
         return result
 
     def __rpow__(self, other):
         # other must be a scalar
-        result = self._new_like_me(_get_common_dtype(self, other))
-        self._rpow_scalar(result, other, self)
+        common_dtype = _get_common_dtype(self, other, self.queue)
+        result = self._new_like_me(common_dtype)
+        self._rpow_scalar(result, common_dtype.type(other), self)
         return result
 
     def reverse(self, queue=None):
@@ -795,6 +812,8 @@ class Array(object):
 
 # }}}
 
+# }}}
+
 # {{{ creation helpers
 
 def to_device(*args, **kwargs):
diff --git a/pyopencl/characterize/__init__.py b/pyopencl/characterize/__init__.py
index 858169aa..b6e92072 100644
--- a/pyopencl/characterize/__init__.py
+++ b/pyopencl/characterize/__init__.py
@@ -2,10 +2,12 @@ from __future__ import division
 
 import pyopencl as cl
 import numpy as np
+from pytools import memoize
 
 class CLCharacterizationWarning(UserWarning):
     pass
 
+@memoize
 def has_double_support(dev):
     for ext in dev.extensions.split(" "):
         if ext == "cl_khr_fp64":
diff --git a/pyopencl/compyte b/pyopencl/compyte
index 25865d27..389cf828 160000
--- a/pyopencl/compyte
+++ b/pyopencl/compyte
@@ -1 +1 @@
-Subproject commit 25865d27ab46752c93630d2b02e289fdf8080e43
+Subproject commit 389cf828b67bdddc83afed6d79bd448076432ec6
diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index a4ec223e..96f58b0a 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -432,8 +432,14 @@ def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z):
     x = "x[i]"
     y = "y[i]"
 
+    if x_is_complex and dtype_x != dtype_z:
+        x = "%s_cast(%s)" % (complex_dtype_to_name(dtype_z), x)
+    if y_is_complex and dtype_y != dtype_z:
+        y = "%s_cast(%s)" % (complex_dtype_to_name(dtype_z), y)
+
     if x_is_complex and y_is_complex:
         xy = "%s_mul(%s, %s)" % (complex_dtype_to_name(dtype_z), x, y)
+
     else:
         xy = "%s * %s" % (x, y)
 
@@ -503,7 +509,7 @@ def get_rdivide_elwise_kernel(context, dtype_x, dtype_y, dtype_z):
             y = "%s_cast(%s)" % (complex_dtype_to_name(dtype_z), y)
 
     if x_is_complex and y_is_complex:
-        yox = "%s_divide(%s, %s)" % (complex_dtype_to_name(dtype_z), y / x)
+        yox = "%s_divide(%s, %s)" % (complex_dtype_to_name(dtype_z), y, x)
     elif not y_is_complex and x_is_complex:
         yox = "%s_rdivide(%s, %s)" % (complex_dtype_to_name(dtype_x), y, x)
     else:
diff --git a/test/test_array.py b/test/test_array.py
index 78ea5535..c12ba43b 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -104,7 +104,8 @@ def test_basic_complex(ctx_factory):
             + 1j* rand(queue, shape=(size,), dtype=np.float32).astype(np.complex64))
     c = np.complex64(5+7j)
 
-    assert ((c*ary).get() == c*ary.get()).all()
+    host_ary = ary.get()
+    assert la.norm((c*ary).get() - c*host_ary) < 1e-5 * la.norm(host_ary)
 
 @pytools.test.mark_test.opencl
 def test_mix_complex(ctx_factory):
diff --git a/test/test_clmath.py b/test/test_clmath.py
index fb8bc770..9c14b91c 100644
--- a/test/test_clmath.py
+++ b/test/test_clmath.py
@@ -90,7 +90,7 @@ if have_cl():
     test_exp = make_unary_function_test("exp", (-3, 3), 1e-5, use_complex=True)
     test_log = make_unary_function_test("log", (1e-5, 1), 1e-6, use_complex=True)
     test_log10 = make_unary_function_test("log10", (1e-5, 1), 5e-7)
-    test_sqrt = make_unary_function_test("sqrt", (1e-5, 1), 2e-7, use_complex=True)
+    test_sqrt = make_unary_function_test("sqrt", (1e-5, 1), 3e-7, use_complex=True)
 
     test_sin = make_unary_function_test("sin", (-10, 10), 2e-7, use_complex=2e-3)
     test_cos = make_unary_function_test("cos", (-10, 10), 2e-7, use_complex=2e-3)
-- 
GitLab