From 6842c798e416a6cb63a75047c0ec873de4d1ca7d Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sun, 4 Apr 2021 23:23:06 -0500
Subject: [PATCH] extend cla.(maximum|minimum|if_positive) to take scalars

---
 pyopencl/array.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index 33ab9683..6e37a5e6 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -42,6 +42,9 @@ from pyopencl.compyte.array import (
         get_common_dtype as _get_common_dtype_base)
 from pyopencl.characterize import has_double_support
 from pyopencl import cltypes
+from numbers import Number
+
+SCALAR_CLASSES = (Number, np.number, np.bool_, bool)
 
 
 _COMMON_DTYPE_CACHE = {}
@@ -2704,6 +2707,17 @@ def if_positive(criterion, then_, else_, out=None, queue=None):
     contains *then_[i]* if *criterion[i]>0*, else *else_[i]*.
     """
 
+    if all(isinstance(k, SCALAR_CLASSES) for k in [criterion,
+                                                   then_,
+                                                   else_]):
+        result = np.where(criterion, then_, else_)
+
+        if out is not None:
+            out[...] = result
+            return
+
+        return result
+
     if not (criterion.shape == then_.shape == else_.shape):
         raise ValueError("shapes do not match")
 
@@ -2719,6 +2733,13 @@ def if_positive(criterion, then_, else_, out=None, queue=None):
 
 def maximum(a, b, out=None, queue=None):
     """Return the elementwise maximum of *a* and *b*."""
+    if all(isinstance(k, SCALAR_CLASSES) for k in [a, b]):
+        result = np.maximum(a, b)
+        if out is not None:
+            out[...] = result
+            return
+
+        return result
 
     # silly, but functional
     return if_positive(a.mul_add(1, b, -1, queue=queue), a, b,
@@ -2727,6 +2748,13 @@ def maximum(a, b, out=None, queue=None):
 
 def minimum(a, b, out=None, queue=None):
     """Return the elementwise minimum of *a* and *b*."""
+    if all(isinstance(k, SCALAR_CLASSES) for k in [a, b]):
+        result = np.minimum(a, b)
+        if out is not None:
+            out[...] = result
+            return
+
+        return result
     # silly, but functional
     return if_positive(a.mul_add(1, b, -1, queue=queue), b, a,
             queue=queue, out=out)
-- 
GitLab