diff --git a/pyopencl/array.py b/pyopencl/array.py index 999c45442a13f1db2aa0d8ff22d6c438aa26bd72..6ccb093fc137bed492546c9b1229cdb045d10f48 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -627,8 +627,13 @@ class Array(object): self._axpbz(result, common_dtype.type(scalar), self, self.dtype.type(0)) return result - def __imul__(self, scalar): - self._axpbz(self, scalar, self, self.dtype.type(0)) + def __imul__(self, other): + if isinstance(other, Array): + self._elwise_multiply(self, self, other) + else: + # scalar + self._axpbz(self, other, self, self.dtype.type(0)) + return self def __div__(self, other):