From 9e9d6acb1c44a81d128d9f6e2b698e4093f48a45 Mon Sep 17 00:00:00 2001
From: Martin Weigert <mweigert@mpi-cbg.de>
Date: Wed, 13 Feb 2019 01:03:12 +0100
Subject: [PATCH] Fix ImportError

---
 pyopencl/array.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index 94f9e25a..046c841c 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -42,8 +42,7 @@ from pyopencl.compyte.array import (
         c_contiguous_strides as _c_contiguous_strides,
         equal_strides as _equal_strides,
         ArrayFlags as _ArrayFlags,
-        get_common_dtype as _get_common_dtype_base,
-        get_truedivide_dtype as _get_truedivide_dtype_base)
+        get_common_dtype as _get_common_dtype_base)
 from pyopencl.characterize import has_double_support
 from pyopencl import cltypes
 
@@ -53,9 +52,25 @@ def _get_common_dtype(obj1, obj2, queue):
                                   has_double_support(queue.device))
 
 
+
 def _get_truedivide_dtype(obj1, obj2, queue):
-    return _get_truedivide_dtype_base(obj1, obj2,
-                                  has_double_support(queue.device))
+    # the dtype of the division result obj1 / obj2
+
+    allow_double = has_double_support(queue.device)
+
+    x1 = obj1 if np.isscalar(obj1) else np.ones(1, obj1.dtype)
+    x2 = obj2 if np.isscalar(obj2) else np.ones(1, obj2.dtype)
+
+    result = (x1/x2).dtype
+
+    if not allow_double:
+        if result == np.float64:
+            result = np.dtype(np.float32)
+        elif result == np.complex128:
+            result = np.dtype(np.complex64)
+
+    return result
+
 
 
 # Work around PyPy not currently supporting the object dtype.
-- 
GitLab