diff --git a/pyopencl/array.py b/pyopencl/array.py index ea80e09c27627d350a5ee88040c7d1bb4524b3ef..c0a0cdb391b43904c85a4da2bf0d1635c41223b4 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1256,13 +1256,16 @@ class Array: x = n - self """ - common_dtype = _get_common_dtype(self, other, self.queue) - # other must be a scalar - result = self._new_like_me(common_dtype) - result.add_event( - self._axpbz(result, result.dtype.type(-1), self, - common_dtype.type(other))) - return result + if np.isscalar(other): + common_dtype = _get_common_dtype(self, other, self.queue) + result = self._new_like_me(common_dtype) + result.add_event( + self._axpbz(result, result.dtype.type(-1), self, + common_dtype.type(other))) + + return result + else: + return NotImplemented def __iadd__(self, other): if isinstance(other, Array): @@ -1275,10 +1278,12 @@ class Array: other.dtype.type(1), other)) return self - else: + elif np.isscalar(other): self.add_event( self._axpbz(self, self.dtype.type(1), self, other)) return self + else: + return NotImplemented def __isub__(self, other): if isinstance(other, Array): @@ -1319,13 +1324,16 @@ class Array: else: return NotImplemented - def __rmul__(self, scalar): - common_dtype = _get_common_dtype(self, scalar, self.queue) - result = self._new_like_me(common_dtype) - result.add_event( - self._axpbz(result, - common_dtype.type(scalar), self, self.dtype.type(0))) - return result + def __rmul__(self, other): + if np.isscalar(other): + common_dtype = _get_common_dtype(self, other, self.queue) + result = self._new_like_me(common_dtype) + result.add_event( + self._axpbz(result, + common_dtype.type(other), self, self.dtype.type(0))) + return result + else: + return NotImplemented def __imul__(self, other): if isinstance(other, Array): @@ -1421,13 +1429,14 @@ class Array: if isinstance(other, Array): result = _get_broadcasted_binary_op_result(self, other, self.queue) result.add_event(self._array_binop(result, self, other, op="&")) - else: - # create a new array for the result + return result + elif np.isscalar(other): result = self._new_like_me(common_dtype) result.add_event( self._scalar_binop(result, self, other, op="&")) - - return result + return result + else: + return NotImplemented __rand__ = __and__ # commutes @@ -1441,13 +1450,14 @@ class Array: result = _get_broadcasted_binary_op_result(self, other, self.queue) result.add_event(self._array_binop(result, self, other, op="|")) - else: - # create a new array for the result + return result + elif np.isscalar(other): result = self._new_like_me(common_dtype) result.add_event( self._scalar_binop(result, self, other, op="|")) - - return result + return result + else: + return NotImplemented __ror__ = __or__ # commutes @@ -1460,13 +1470,14 @@ class Array: if isinstance(other, Array): result = _get_broadcasted_binary_op_result(self, other, self.queue) result.add_event(self._array_binop(result, self, other, op="^")) - else: - # create a new array for the result + return result + elif np.isscalar(other): result = self._new_like_me(common_dtype) result.add_event( self._scalar_binop(result, self, other, op="^")) - - return result + return result + else: + return NotImplemented __rxor__ = __xor__ # commutes @@ -1481,11 +1492,13 @@ class Array: raise NotImplementedError("Broadcasting binary op with shapes:" f" {self.shape}, {other.shape}.") self.add_event(self._array_binop(self, self, other, op="&")) - else: + return self + elif np.isscalar(other): self.add_event( self._scalar_binop(self, self, other, op="&")) - - return self + return self + else: + return NotImplemented def __ior__(self, other): common_dtype = _get_common_dtype(self, other, self.queue) @@ -1498,11 +1511,13 @@ class Array: raise NotImplementedError("Broadcasting binary op with shapes:" f" {self.shape}, {other.shape}.") self.add_event(self._array_binop(self, self, other, op="|")) - else: + return self + elif np.isscalar(other): self.add_event( self._scalar_binop(self, self, other, op="|")) - - return self + return self + else: + return NotImplemented def __ixor__(self, other): common_dtype = _get_common_dtype(self, other, self.queue) @@ -1515,11 +1530,13 @@ class Array: raise NotImplementedError("Broadcasting binary op with shapes:" f" {self.shape}, {other.shape}.") self.add_event(self._array_binop(self, self, other, op="^")) - else: + return self + elif np.isscalar(other): self.add_event( self._scalar_binop(self, self, other, op="^")) - - return self + return self + else: + return NotImplemented def _zero_fill(self, queue=None, wait_for=None): queue = queue or self.queue @@ -1582,20 +1599,24 @@ class Array: _get_common_dtype(self, other, self.queue)) result.add_event( self._pow_array(result, self, other)) - else: + return result + elif np.isscalar(other): result = self._new_like_me( _get_common_dtype(self, other, self.queue)) result.add_event(self._pow_scalar(result, self, other)) - - return result + return result + else: + return NotImplemented def __rpow__(self, other): - # other must be a scalar - common_dtype = _get_common_dtype(self, other, self.queue) - result = self._new_like_me(common_dtype) - result.add_event( - self._rpow_scalar(result, common_dtype.type(other), self)) - return result + if np.isscalar(other): + common_dtype = _get_common_dtype(self, other, self.queue) + result = self._new_like_me(common_dtype) + result.add_event( + self._rpow_scalar(result, common_dtype.type(other), self)) + return result + else: + return NotImplemented def __invert__(self): if not np.issubdtype(self.dtype, np.integer): @@ -1675,11 +1696,13 @@ class Array: result.add_event( self._array_comparison(result, self, other, op="==")) return result - else: + elif np.isscalar(other): result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="==")) return result + else: + return NotImplemented def __ne__(self, other): if isinstance(other, Array): @@ -1687,11 +1710,13 @@ class Array: result.add_event( self._array_comparison(result, self, other, op="!=")) return result - else: + elif np.isscalar(other): result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="!=")) return result + else: + return NotImplemented def __le__(self, other): if isinstance(other, Array): @@ -1699,10 +1724,12 @@ class Array: result.add_event( self._array_comparison(result, self, other, op="<=")) return result - else: + elif np.isscalar(other): result = self._new_like_me(_BOOL_DTYPE) self._scalar_comparison(result, self, other, op="<=") return result + else: + return NotImplemented def __ge__(self, other): if isinstance(other, Array): @@ -1710,11 +1737,13 @@ class Array: result.add_event( self._array_comparison(result, self, other, op=">=")) return result - else: + elif np.isscalar(other): result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op=">=")) return result + else: + return NotImplemented def __lt__(self, other): if isinstance(other, Array): @@ -1722,11 +1751,13 @@ class Array: result.add_event( self._array_comparison(result, self, other, op="<")) return result - else: + elif np.isscalar(other): result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op="<")) return result + else: + return NotImplemented def __gt__(self, other): if isinstance(other, Array): @@ -1734,11 +1765,13 @@ class Array: result.add_event( self._array_comparison(result, self, other, op=">")) return result - else: + elif np.isscalar(other): result = self._new_like_me(_BOOL_DTYPE) result.add_event( self._scalar_comparison(result, self, other, op=">")) return result + else: + return NotImplemented # }}}