From 9e9d6acb1c44a81d128d9f6e2b698e4093f48a45 Mon Sep 17 00:00:00 2001 From: Martin Weigert <mweigert@mpi-cbg.de> Date: Wed, 13 Feb 2019 01:03:12 +0100 Subject: [PATCH] Fix ImportError --- pyopencl/array.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index 94f9e25a..046c841c 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -42,8 +42,7 @@ from pyopencl.compyte.array import ( c_contiguous_strides as _c_contiguous_strides, equal_strides as _equal_strides, ArrayFlags as _ArrayFlags, - get_common_dtype as _get_common_dtype_base, - get_truedivide_dtype as _get_truedivide_dtype_base) + get_common_dtype as _get_common_dtype_base) from pyopencl.characterize import has_double_support from pyopencl import cltypes @@ -53,9 +52,25 @@ def _get_common_dtype(obj1, obj2, queue): has_double_support(queue.device)) + def _get_truedivide_dtype(obj1, obj2, queue): - return _get_truedivide_dtype_base(obj1, obj2, - has_double_support(queue.device)) + # the dtype of the division result obj1 / obj2 + + allow_double = has_double_support(queue.device) + + x1 = obj1 if np.isscalar(obj1) else np.ones(1, obj1.dtype) + x2 = obj2 if np.isscalar(obj2) else np.ones(1, obj2.dtype) + + result = (x1/x2).dtype + + if not allow_double: + if result == np.float64: + result = np.dtype(np.float32) + elif result == np.complex128: + result = np.dtype(np.complex64) + + return result + # Work around PyPy not currently supporting the object dtype. -- GitLab