diff --git a/pyopencl/array.py b/pyopencl/array.py index b6f52abe2e2e4df72d395beaaaf6c0527f0bcc9f..f3dd5e74d467ac175f2c8fcbfa448057ddd19807 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -32,7 +32,7 @@ import builtins from dataclasses import dataclass from functools import reduce from numbers import Number -from typing import Any, Dict, Hashable, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn import numpy as np @@ -58,20 +58,14 @@ else: _SVMPointer_or_nothing = () -_NUMPY_PRE_2 = np.__version__.startswith("1.") - - # {{{ _get_common_dtype -_COMMON_DTYPE_CACHE: Dict[Tuple[Hashable, ...], np.dtype] = {} - - class DoubleDowncastWarning(UserWarning): pass _DOUBLE_DOWNCAST_WARNING = ( - "The operation you requested would result in a double-precisision " + "The operation you requested would result in a double-precision " "quantity according to numpy semantics. Since your device does not " "support double precision, a single-precision quantity is being returned.") @@ -81,78 +75,9 @@ def _get_common_dtype(obj1, obj2, queue): raise ValueError("PyOpenCL array has no queue; call .with_queue() to " "add one in order to be able to perform operations") - allow_double = has_double_support(queue.device) - cache_key = None - o1_is_array = isinstance(obj1, Array) - o2_is_array = isinstance(obj2, Array) - - if o1_is_array and o2_is_array: - o1_dtype = obj1.dtype - o2_dtype = obj2.dtype - cache_key = (obj1.dtype, obj2.dtype, allow_double) - else: - o1_dtype = getattr(obj1, "dtype", type(obj1)) - o2_dtype = getattr(obj2, "dtype", type(obj2)) - - o1_is_integral = np.issubdtype(o1_dtype, np.integer) - o2_is_integral = np.issubdtype(o1_dtype, np.integer) - - o1_key = obj1 if o1_is_integral and not o1_is_array else o1_dtype - o2_key = obj2 if o2_is_integral and not o2_is_array else o2_dtype + result = np.result_type(obj1, obj2) - cache_key = (o1_key, o2_key, o1_is_array, o2_is_array, allow_double) - - try: - return _COMMON_DTYPE_CACHE[cache_key] - except KeyError: - pass - - # Numpy's behavior around integers is a bit bizarre, and definitely value- - # and not just type-sensitive when it comes to scalars. We'll just do our - # best to emulate it. - # - # Some samples that are true as of numpy 1.23.1. - # - # >>> a = np.zeros(1, dtype=np.int16) - # >>> (a + 123123123312).dtype - # dtype('int64') - # >>> (a + 12312).dtype - # dtype('int16') - # >>> (a + 12312444).dtype - # dtype('int32') - # >>> (a + np.int32(12312444)).dtype - # dtype('int32') - # >>> (a + np.int32(1234)).dtype - # dtype('int16') - # - # Note that np.find_common_type, while appealing, won't be able to tell - # the full story. - - if (_NUMPY_PRE_2 - and not (o1_is_array and o2_is_array) - and o1_is_integral and o2_is_integral): - if o1_is_array: - obj1 = np.zeros(1, dtype=o1_dtype) - if o2_is_array: - obj2 = np.zeros(1, dtype=o2_dtype) - - result = (obj1 + obj2).dtype - else: - array_types = [] - scalars = [] - - if o1_is_array: - array_types.append(o1_dtype) - else: - scalars.append(obj1) - if o2_is_array: - array_types.append(o2_dtype) - else: - scalars.append(obj2) - - result = np.result_type(*array_types, *scalars) - - if not allow_double: + if not has_double_support(queue.device): if result == np.float64: result = np.dtype(np.float32) warn(_DOUBLE_DOWNCAST_WARNING, DoubleDowncastWarning, stacklevel=3) @@ -160,9 +85,6 @@ def _get_common_dtype(obj1, obj2, queue): result = np.dtype(np.complex64) warn(_DOUBLE_DOWNCAST_WARNING, DoubleDowncastWarning, stacklevel=3) - if cache_key is not None: - _COMMON_DTYPE_CACHE[cache_key] = result - return result # }}}