diff --git a/pyopencl/array.py b/pyopencl/array.py index 3e7721e8ec3a203509d58219e6a7b51d5967e78d..fd0fedc09c290fecc335cbc63f81b01d6bd6c781 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -75,51 +75,72 @@ def _get_common_dtype(obj1, obj2, queue): allow_double = has_double_support(queue.device) cache_key = None - o1_dtype = getattr(obj1, "dtype", type(obj1)) - o2_dtype = getattr(obj2, "dtype", type(obj2)) o1_is_array = isinstance(obj1, Array) o2_is_array = isinstance(obj2, Array) - if o1_dtype is int: - # must be an unsized scalar - o1_dtype = np.array(obj1).dtype - - # Er, what? Well, consider this, which is true as of numpy 1.23.1: - # - # np.find_common_type([int16], [int64]) == int16 - # - # apparently because of some hare-brained no-upcasts-because-of-scalars - # rule which even numpy itself ignores when determining result dtypes - # of an operation on integers. For non-integers OTOH, it seems to heed - # its own advice, but *shrug*. So we'll tell numpy that scalar integers - # "aren't scalars" to give their types "a bit more weight". *shudders* - o1_is_array = True - - if o2_dtype is int: - # must be an unsized scalar - o2_dtype = np.array(obj2).dtype - - # See above. 
- o2_is_array = True - - cache_key = (o1_dtype, o2_dtype, o1_is_array, o2_is_array, allow_double) - #try: - # return _COMMON_DTYPE_CACHE[cache_key] - #except KeyError: - # pass - - array_types = [] - scalar_types = [] - if o1_is_array: - array_types.append(o1_dtype) + if o1_is_array and o2_is_array: + o1_dtype = obj1.dtype + o2_dtype = obj2.dtype + cache_key = (obj1.dtype, obj2.dtype, allow_double) else: - scalar_types.append(o1_dtype) - if o2_is_array: - array_types.append(o2_dtype) + o1_dtype = getattr(obj1, "dtype", type(obj1)) + o2_dtype = getattr(obj2, "dtype", type(obj2)) + + o1_is_integral = np.issubdtype(o1_dtype, np.integer) + o2_is_integral = np.issubdtype(o2_dtype, np.integer) + + o1_key = obj1 if o1_is_integral and not o1_is_array else o1_dtype + o2_key = obj2 if o2_is_integral and not o2_is_array else o2_dtype + + cache_key = (o1_key, o2_key, o1_is_array, o2_is_array, allow_double) + + try: + return _COMMON_DTYPE_CACHE[cache_key] + except KeyError: + pass + + # Numpy's behavior around integers is a bit bizarre, and definitely value- + # and not just type-sensitive when it comes to scalars. We'll just do our + # best to emulate it. + # + # Some samples that are true as of numpy 1.23.1. + # + # >>> a = np.zeros(1, dtype=np.int16) + # >>> (a + 123123123312).dtype + # dtype('int64') + # >>> (a + 12312).dtype + # dtype('int16') + # >>> (a + 12312444).dtype + # dtype('int32') + # >>> (a + np.int32(12312444)).dtype + # dtype('int32') + # >>> (a + np.int32(1234)).dtype + # dtype('int16') + # + # Note that np.find_common_type, while appealing, won't be able to tell + # the full story. 
+ + if not (o1_is_array and o2_is_array) and o1_is_integral and o2_is_integral: + if o1_is_array: + obj1 = np.zeros(1, dtype=o1_dtype) + if o2_is_array: + obj2 = np.zeros(1, dtype=o2_dtype) + + result = (obj1 + obj2).dtype else: - scalar_types.append(o2_dtype) + array_types = [] + scalar_types = [] + + if o1_is_array: + array_types.append(o1_dtype) + else: + scalar_types.append(o1_dtype) + if o2_is_array: + array_types.append(o2_dtype) + else: + scalar_types.append(o2_dtype) - result = np.find_common_type(array_types, scalar_types) + result = np.find_common_type(array_types, scalar_types) if not allow_double: if result == np.float64: