diff --git a/pyopencl/array.py b/pyopencl/array.py index a991b237fc1f5d73fe45384090a74fde59f07233..8364dbe9ef10fe58b69e369c91342afe53b41110 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1086,10 +1086,25 @@ class Array: def mul_add(self, selffac, other, otherfac, queue=None): """Return `selffac * self + otherfac*other`. """ - result = _get_broadcasted_binary_op_result(self, other, queue or self.queue) - result.add_event( - self._axpbyz(result, selffac, self, otherfac, other)) - return result + queue = queue or self.queue + + if isinstance(other, Array): + result = _get_broadcasted_binary_op_result(self, other, queue) + result.add_event( + self._axpbyz( + result, selffac, self, otherfac, other, + queue=queue)) + return result + elif np.isscalar(other): + common_dtype = _get_common_dtype(self, other, queue) + result = self._new_like_me(common_dtype, queue=queue) + result.add_event( + self._axpbz(result, selffac, + self, common_dtype.type(otherfac * other), + queue=queue)) + return result + else: + raise NotImplementedError def __add__(self, other): """Add an array with an array or an array with a scalar.""" @@ -2869,8 +2884,14 @@ def maximum(a, b, out=None, queue=None): return result # silly, but functional - return if_positive(a.mul_add(1, b, -1, queue=queue), a, b, - queue=queue, out=out) + if isinstance(a, Array): + criterion = a.mul_add(1, b, -1, queue=queue) + elif isinstance(b, Array): + criterion = b.mul_add(-1, a, 1, queue=queue) + else: + raise AssertionError + + return if_positive(criterion, a, b, queue=queue, out=out) def minimum(a, b, out=None, queue=None): @@ -2882,9 +2903,16 @@ def minimum(a, b, out=None, queue=None): return out return result + # silly, but functional - return if_positive(a.mul_add(1, b, -1, queue=queue), b, a, - queue=queue, out=out) + if isinstance(a, Array): + criterion = a.mul_add(1, b, -1, queue=queue) + elif isinstance(b, Array): + criterion = b.mul_add(-1, a, 1, queue=queue) + else: + raise AssertionError + + return if_positive(criterion, b, a, queue=queue, out=out) # }}} diff --git a/test/test_array.py b/test/test_array.py index f0510bf89acac583dbe4068db67eba8f9d5e71eb..a94e36ac903cc70db6a4c9e84ebed230954605c1 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1680,6 +1680,31 @@ def test_if_positive_with_scalars(ctx_factory, then_type, else_type): np.testing.assert_allclose(result_cl.get(), result_np) +def test_maximum_minimum_with_scalars(ctx_factory): + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + a_np = np.float64(4.0) + a_cl = cl_array.to_device(cq, np.array(a_np)).with_queue(None) + + b_np = np.float64(-3.0) + b_cl = cl_array.to_device(cq, np.array(b_np)).with_queue(None) + + result = cl_array.maximum(a_np, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), a_np) + result = cl_array.maximum(a_cl, b_np, queue=cq) + np.testing.assert_allclose(result.get(), a_np) + result = cl_array.maximum(a_cl, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), a_np) + + result = cl_array.minimum(a_np, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), b_np) + result = cl_array.minimum(a_cl, b_np, queue=cq) + np.testing.assert_allclose(result.get(), b_np) + result = cl_array.minimum(a_cl, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), b_np) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])