diff --git a/pyopencl/array.py b/pyopencl/array.py index 699466aa4b4055a1c5a3f8738200839742f8e7bf..3e7721e8ec3a203509d58219e6a7b51d5967e78d 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -42,8 +42,7 @@ from pyopencl.compyte.array import ( f_contiguous_strides as _f_contiguous_strides, c_contiguous_strides as _c_contiguous_strides, equal_strides as _equal_strides, - ArrayFlags as _ArrayFlags, - get_common_dtype as _get_common_dtype_base) + ArrayFlags as _ArrayFlags) from pyopencl.characterize import has_double_support from pyopencl import cltypes from numbers import Number @@ -59,35 +58,77 @@ else: _COMMON_DTYPE_CACHE = {} +class DoubleDowncastWarning(UserWarning): + pass + + +_DOUBLE_DOWNCAST_WARNING = ( + "The operation you requested would result in a double-precision " + "quantity according to numpy semantics. Since your device does not " + "support double precision, a single-precision quantity is being returned.") + + def _get_common_dtype(obj1, obj2, queue): if queue is None: raise ValueError("PyOpenCL array has no queue; call .with_queue() to " "add one in order to be able to perform operations") - dsupport = has_double_support(queue.device) + allow_double = has_double_support(queue.device) cache_key = None - o1_dtype = obj1.dtype - try: - cache_key = (o1_dtype, obj2.dtype, dsupport) - return _COMMON_DTYPE_CACHE[cache_key] - except KeyError: - pass - except AttributeError: - # obj2 doesn't have a dtype - try: - tobj2 = type(obj2) - cache_key = (o1_dtype, tobj2, dsupport) + o1_dtype = getattr(obj1, "dtype", type(obj1)) + o2_dtype = getattr(obj2, "dtype", type(obj2)) + o1_is_array = isinstance(obj1, Array) + o2_is_array = isinstance(obj2, Array) + + if o1_dtype is int: + # must be an unsized scalar + o1_dtype = np.array(obj1).dtype + + # Er, what? 
Well, consider this, which is true as of numpy 1.23.1: + # + # np.find_common_type([int16], [int64]) == int16 + # + # apparently because of some hare-brained no-upcasts-because-of-scalars + # rule which even numpy itself ignores when determining result dtypes + # of an operation on integers. For non-integers OTOH, it seems to heed + # its own advice, but *shrug*. So we'll tell numpy that scalar integers + # "aren't scalars" to give their types "a bit more weight". *shudders* + o1_is_array = True + + if o2_dtype is int: + # must be an unsized scalar + o2_dtype = np.array(obj2).dtype + + # See above. + o2_is_array = True + + cache_key = (o1_dtype, o2_dtype, o1_is_array, o2_is_array, allow_double) + try: + return _COMMON_DTYPE_CACHE[cache_key] + except KeyError: + pass + + array_types = [] + scalar_types = [] + if o1_is_array: + array_types.append(o1_dtype) + else: + scalar_types.append(o1_dtype) + if o2_is_array: + array_types.append(o2_dtype) + else: + scalar_types.append(o2_dtype) - # Integers are weird, sized, and signed. Don't pretend that 'int' - # is enough information to decide what should happen. 
- if tobj2 != int: - return _COMMON_DTYPE_CACHE[cache_key] - except KeyError: - pass + result = np.find_common_type(array_types, scalar_types) - result = _get_common_dtype_base(obj1, obj2, dsupport) + if not allow_double: + if result == np.float64: + result = np.dtype(np.float32) + warn(_DOUBLE_DOWNCAST_WARNING, DoubleDowncastWarning, stacklevel=3) + elif result == np.complex128: + result = np.dtype(np.complex64) + warn(_DOUBLE_DOWNCAST_WARNING, DoubleDowncastWarning, stacklevel=3) - # we succeeded in constructing the cache key if cache_key is not None: _COMMON_DTYPE_CACHE[cache_key] = result diff --git a/test/test_array.py b/test/test_array.py index ffec2f4d21a48e917a5e2e63c78f86c6dfb4af2f..ca5113522bdcdeb82e75cea47c66e2070c51614c 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -93,6 +93,7 @@ def test_basic_complex(ctx_factory): ary = (rand(queue, shape=(size,), dtype=np.float32).astype(np.complex64) + rand(queue, shape=(size,), dtype=np.float32).astype(np.complex64) * 1j) + assert ary.dtype != np.dtype(np.complex128) c = np.complex64(5+7j) host_ary = ary.get()