diff --git a/pyopencl/array.py b/pyopencl/array.py index 33ab96832b949f4d0735a32d970f037c74780e6f..6e37a5e66cb7029aea2549cf25777b155b4d28f6 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -42,6 +42,9 @@ from pyopencl.compyte.array import ( get_common_dtype as _get_common_dtype_base) from pyopencl.characterize import has_double_support from pyopencl import cltypes +from numbers import Number + +SCALAR_CLASSES = (Number, np.number, np.bool_, bool) _COMMON_DTYPE_CACHE = {} @@ -2704,6 +2707,17 @@ def if_positive(criterion, then_, else_, out=None, queue=None): contains *then_[i]* if *criterion[i]>0*, else *else_[i]*. """ + if all(isinstance(k, SCALAR_CLASSES) for k in [criterion, + then_, + else_]): + result = np.where(criterion, then_, else_) + + if out is not None: + out[...] = result + return + + return result + if not (criterion.shape == then_.shape == else_.shape): raise ValueError("shapes do not match") @@ -2719,6 +2733,13 @@ def if_positive(criterion, then_, else_, out=None, queue=None): def maximum(a, b, out=None, queue=None): """Return the elementwise maximum of *a* and *b*.""" + if all(isinstance(k, SCALAR_CLASSES) for k in [a, b]): + result = np.maximum(a, b) + if out is not None: + out[...] = result + return + + return result # silly, but functional return if_positive(a.mul_add(1, b, -1, queue=queue), a, b, @@ -2727,6 +2748,13 @@ def maximum(a, b, out=None, queue=None): def minimum(a, b, out=None, queue=None): """Return the elementwise minimum of *a* and *b*.""" + if all(isinstance(k, SCALAR_CLASSES) for k in [a, b]): + result = np.minimum(a, b) + if out is not None: + out[...] = result + return + + return result # silly, but functional return if_positive(a.mul_add(1, b, -1, queue=queue), b, a, queue=queue, out=out)