From 6d8c70e3b68c17808c72708f5c6b873799615b19 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sun, 17 Oct 2021 19:50:23 -0500 Subject: [PATCH] ensure and maximum works with combinations of scalars --- pyopencl/array.py | 44 ++++++++++++++++++++++++++++++++++++-------- test/test_array.py | 25 +++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index a991b237..8364dbe9 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1086,10 +1086,25 @@ class Array: def mul_add(self, selffac, other, otherfac, queue=None): """Return `selffac * self + otherfac*other`. """ - result = _get_broadcasted_binary_op_result(self, other, queue or self.queue) - result.add_event( - self._axpbyz(result, selffac, self, otherfac, other)) - return result + queue = queue or self.queue + + if isinstance(other, Array): + result = _get_broadcasted_binary_op_result(self, other, queue) + result.add_event( + self._axpbyz( + result, selffac, self, otherfac, other, + queue=queue)) + return result + elif np.isscalar(other): + common_dtype = _get_common_dtype(self, other, queue) + result = self._new_like_me(common_dtype, queue=queue) + result.add_event( + self._axpbz(result, selffac, + self, common_dtype.type(otherfac * other), + queue=queue)) + return result + else: + raise NotImplementedError def __add__(self, other): """Add an array with an array or an array with a scalar.""" @@ -2869,8 +2884,14 @@ def maximum(a, b, out=None, queue=None): return result # silly, but functional - return if_positive(a.mul_add(1, b, -1, queue=queue), a, b, - queue=queue, out=out) + if isinstance(a, Array): + criterion = a.mul_add(1, b, -1, queue=queue) + elif isinstance(b, Array): + criterion = b.mul_add(-1, a, 1, queue=queue) + else: + raise AssertionError + + return if_positive(criterion, a, b, queue=queue, out=out) def minimum(a, b, out=None, queue=None): @@ -2882,9 +2903,16 @@ def minimum(a, b, out=None, queue=None): return out return result + # silly, but functional - return if_positive(a.mul_add(1, b, -1, queue=queue), b, a, - queue=queue, out=out) + if isinstance(a, Array): + criterion = a.mul_add(1, b, -1, queue=queue) + elif isinstance(b, Array): + criterion = b.mul_add(-1, a, 1, queue=queue) + else: + raise AssertionError + + return if_positive(criterion, b, a, queue=queue, out=out) # }}} diff --git a/test/test_array.py b/test/test_array.py index f0510bf8..a94e36ac 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1680,6 +1680,31 @@ def test_if_positive_with_scalars(ctx_factory, then_type, else_type): np.testing.assert_allclose(result_cl.get(), result_np) +def test_maximum_minimum_with_scalars(ctx_factory): + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + a_np = np.float64(4.0) + a_cl = cl_array.to_device(cq, np.array(a_np)).with_queue(None) + + b_np = np.float64(-3.0) + b_cl = cl_array.to_device(cq, np.array(b_np)).with_queue(None) + + result = cl_array.maximum(a_np, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), a_np) + result = cl_array.maximum(a_cl, b_np, queue=cq) + np.testing.assert_allclose(result.get(), a_np) + result = cl_array.maximum(a_cl, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), a_np) + + result = cl_array.minimum(a_np, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), b_np) + result = cl_array.minimum(a_cl, b_np, queue=cq) + np.testing.assert_allclose(result.get(), b_np) + result = cl_array.minimum(a_cl, b_cl, queue=cq) + np.testing.assert_allclose(result.get(), b_np) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab