From edffdfdf3bebc4d0162acc35912e2adca2eea083 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 2 Oct 2024 16:34:29 -0500 Subject: [PATCH] add test, comment --- pyopencl/array.py | 3 +++ test/test_array.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/pyopencl/array.py b/pyopencl/array.py index f3dd5e74..16e24945 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -75,6 +75,9 @@ def _get_common_dtype(obj1, obj2, queue): raise ValueError("PyOpenCL array has no queue; call .with_queue() to " "add one in order to be able to perform operations") + # Note: We are calling np.result_type with pyopencl arrays here. + # Luckily, np.result_type only looks at the dtype of input arrays as of + # numpy v2.1. result = np.result_type(obj1, obj2) if not has_double_support(queue.device): diff --git a/test/test_array.py b/test/test_array.py index 4b648878..f9ec27a8 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -2393,6 +2393,24 @@ def test_xdg_cache_home(ctx_factory): # }}} +def test_numpy_type_promotion_with_cl_arrays(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + class NotReallyAnArray: + @property + def dtype(self): + return np.dtype("float64") + + # Make sure that np.result_type accesses only the dtype attribute of the + # class, not (e.g.) its data. + assert np.result_type(42, NotReallyAnArray()) == np.float64 + + from pyopencl.array import _get_common_dtype + assert _get_common_dtype(42, NotReallyAnArray(), queue) == np.float64 + assert _get_common_dtype(42.0, NotReallyAnArray(), queue) == np.float64 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab