diff --git a/pyopencl/array.py b/pyopencl/array.py index 03b87209ee715b684be44414f3acea550234a9a0..33ab96832b949f4d0735a32d970f037c74780e6f 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -227,6 +227,7 @@ class _copy_queue: # noqa _ARRAY_GET_SIZES_CACHE = {} +_BOOL_DTYPE = np.dtype(np.int8) class Array: @@ -1482,71 +1483,71 @@ class Array: def __eq__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="==")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="==")) return result def __ne__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="!=")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="!=")) return result def __le__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="<=")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) self._scalar_comparison(result, self, other, op="<=") return result def __ge__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op=">=")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op=">=")) return result def __lt__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op="<")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="<")) return result def __gt__(self, other): if isinstance(other, Array): - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._array_comparison(result, self, other, op=">")) return result else: - result = self._new_like_me(np.int8) + result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op=">")) return result diff --git a/test/test_array.py b/test/test_array.py index d8b0d7bd00fb95dafd2066ee26577b855edb7530..a6a91bef4e11089bb4223a5279420bcd155c6334 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1046,6 +1046,10 @@ def test_comparisons(ctx_factory): assert (res_dev.get() == res).all() + res2_dev = op(0, res_dev) + res2 = op(0, res) + assert (res2_dev.get() == res2).all() + def test_any_all(ctx_factory): context = ctx_factory()