From edfbf18b9d5c01393bfa651bc7b3cc941891a124 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 17 May 2021 17:10:05 -0500 Subject: [PATCH] Fix, test comparison on np.int8 --- pyopencl/array.py | 25 +++++++++++++------------ test/test_array.py | 4 ++++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index 03b87209..33ab9683 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -227,6 +227,7 @@ class _copy_queue: # noqa _ARRAY_GET_SIZES_CACHE = {} +_BOOL_DTYPE = np.dtype(np.int8) class Array: @@ -1482,71 +1483,71 @@ class Array: def __eq__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="==")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="==")) return result def __ne__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="!=")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="!=")) return result def __le__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="<=")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) self._scalar_comparison(result, self, other, op="<=") return result def __ge__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op=">=")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op=">=")) return result def __lt__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="<")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="<")) return result def __gt__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op=">")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op=">")) return result diff --git a/test/test_array.py b/test/test_array.py index d8b0d7bd..a6a91bef 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1046,6 +1046,10 @@ def test_comparisons(ctx_factory): assert (res_dev.get() == res).all() + res2_dev = op(0, res_dev) + res2 = op(0, res) + assert (res2_dev.get() == res2).all() + def test_any_all(ctx_factory): context = ctx_factory() -- GitLab