diff --git a/pyopencl/array.py b/pyopencl/array.py index f3dd5e74d467ac175f2c8fcbfa448057ddd19807..16e249456b4a34fc1aef147968bec2cc310d9515 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -75,6 +75,9 @@ def _get_common_dtype(obj1, obj2, queue): raise ValueError("PyOpenCL array has no queue; call .with_queue() to " "add one in order to be able to perform operations") + # Note: We are calling np.result_type with pyopencl arrays here. + # Luckily, np.result_type only looks at the dtype of input arrays as of + # numpy v2.1. result = np.result_type(obj1, obj2) if not has_double_support(queue.device): diff --git a/test/test_array.py b/test/test_array.py index 4b6488781998af60b7fb8397040c61f47fd804b2..f9ec27a837666269fcbf7e9a9807d73142cdcf77 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -2393,6 +2393,24 @@ def test_xdg_cache_home(ctx_factory): # }}} +def test_numpy_type_promotion_with_cl_arrays(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + class NotReallyAnArray: + @property + def dtype(self): + return np.dtype("float64") + + # Make sure that np.result_type accesses only the dtype attribute of the + # class, not (e.g.) its data. + assert np.result_type(42, NotReallyAnArray()) == np.float64 + + from pyopencl.array import _get_common_dtype + assert _get_common_dtype(42, NotReallyAnArray(), queue) == np.float64 + assert _get_common_dtype(42.0, NotReallyAnArray(), queue) == np.float64 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])