From edfbf18b9d5c01393bfa651bc7b3cc941891a124 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 17 May 2021 17:10:05 -0500
Subject: [PATCH] Fix, test comparison on np.int8

---
 pyopencl/array.py  | 25 +++++++++++++------------
 test/test_array.py |  4 ++++
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index 03b87209..33ab9683 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -227,6 +227,7 @@ class _copy_queue:  # noqa
 
 
 _ARRAY_GET_SIZES_CACHE = {}
+_BOOL_DTYPE = np.dtype(np.int8)
 
 
 class Array:
@@ -1482,71 +1483,71 @@ class Array:
 
     def __eq__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._array_comparison(result, self, other, op="=="))
             return result
         else:
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._scalar_comparison(result, self, other, op="=="))
             return result
 
     def __ne__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._array_comparison(result, self, other, op="!="))
             return result
         else:
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._scalar_comparison(result, self, other, op="!="))
             return result
 
     def __le__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._array_comparison(result, self, other, op="<="))
             return result
         else:
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             self._scalar_comparison(result, self, other, op="<=")
             return result
 
     def __ge__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._array_comparison(result, self, other, op=">="))
             return result
         else:
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._scalar_comparison(result, self, other, op=">="))
             return result
 
     def __lt__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._array_comparison(result, self, other, op="<"))
             return result
         else:
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._scalar_comparison(result, self, other, op="<"))
             return result
 
     def __gt__(self, other):
         if isinstance(other, Array):
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._array_comparison(result, self, other, op=">"))
             return result
         else:
-            result = self._new_like_me(np.int8)
+            result = self._new_like_me(_BOOL_DTYPE)
             result.add_event(
                     self._scalar_comparison(result, self, other, op=">"))
             return result
diff --git a/test/test_array.py b/test/test_array.py
index d8b0d7bd..a6a91bef 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -1046,6 +1046,10 @@ def test_comparisons(ctx_factory):
 
         assert (res_dev.get() == res).all()
 
+        res2_dev = op(0, res_dev)
+        res2 = op(0, res)
+        assert (res2_dev.get() == res2).all()
+
 
 def test_any_all(ctx_factory):
     context = ctx_factory()
-- 
GitLab