diff --git a/pyopencl/array.py b/pyopencl/array.py index 33ab96832b949f4d0735a32d970f037c74780e6f..c05ab2f0039c6868284fa73520fa4fec7da4982d 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -42,6 +42,9 @@ from pyopencl.compyte.array import ( get_common_dtype as _get_common_dtype_base) from pyopencl.characterize import has_double_support from pyopencl import cltypes +from numbers import Number + +SCALAR_CLASSES = (Number, np.bool_, bool) _COMMON_DTYPE_CACHE = {} @@ -2704,6 +2707,17 @@ def if_positive(criterion, then_, else_, out=None, queue=None): contains *then_[i]* if *criterion[i]>0*, else *else_[i]*. """ + if (isinstance(criterion, SCALAR_CLASSES) + and isinstance(then_, SCALAR_CLASSES) + and isinstance(else_, SCALAR_CLASSES)): + result = np.where(criterion, then_, else_) + + if out is not None: + out[...] = result + return out + + return result + if not (criterion.shape == then_.shape == else_.shape): raise ValueError("shapes do not match") @@ -2719,6 +2733,13 @@ def if_positive(criterion, then_, else_, out=None, queue=None): def maximum(a, b, out=None, queue=None): """Return the elementwise maximum of *a* and *b*.""" + if isinstance(a, SCALAR_CLASSES) and isinstance(b, SCALAR_CLASSES): + result = np.maximum(a, b) + if out is not None: + out[...] = result + return out + + return result # silly, but functional return if_positive(a.mul_add(1, b, -1, queue=queue), a, b, @@ -2727,6 +2748,13 @@ def maximum(a, b, out=None, queue=None): def minimum(a, b, out=None, queue=None): """Return the elementwise minimum of *a* and *b*.""" + if isinstance(a, SCALAR_CLASSES) and isinstance(b, SCALAR_CLASSES): + result = np.minimum(a, b) + if out is not None: + out[...] = result + return out + + return result # silly, but functional return if_positive(a.mul_add(1, b, -1, queue=queue), b, a, queue=queue, out=out) diff --git a/test/test_array.py b/test/test_array.py index a6a91bef4e11089bb4223a5279420bcd155c6334..3bfd0d04bb463174e40fefa59304ca1b1e9b4f8f 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1560,6 +1560,19 @@ def test_assign_different_strides(ctx_factory): b[:] = a +def test_branch_operations_on_pure_scalars(ctx_factory): + x = np.random.rand() + y = np.random.rand() + cond = np.random.choice([False, True]) + + np.testing.assert_allclose(np.maximum(x, y), + cl_array.maximum(x, y)) + np.testing.assert_allclose(np.minimum(x, y), + cl_array.minimum(x, y)) + np.testing.assert_allclose(np.where(cond, x, y), + cl_array.if_positive(cond, x, y)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])