Skip to content
Snippets Groups Projects
Commit 6842c798 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

extend cla.(maximum|minimum|if_positive) to take scalars

parent 3c8b103e
No related branches found
No related tags found
No related merge requests found
...@@ -42,6 +42,9 @@ from pyopencl.compyte.array import ( ...@@ -42,6 +42,9 @@ from pyopencl.compyte.array import (
get_common_dtype as _get_common_dtype_base) get_common_dtype as _get_common_dtype_base)
from pyopencl.characterize import has_double_support from pyopencl.characterize import has_double_support
from pyopencl import cltypes from pyopencl import cltypes
from numbers import Number
SCALAR_CLASSES = (Number, np.number, np.bool_, bool)
_COMMON_DTYPE_CACHE = {} _COMMON_DTYPE_CACHE = {}
...@@ -2704,6 +2707,17 @@ def if_positive(criterion, then_, else_, out=None, queue=None): ...@@ -2704,6 +2707,17 @@ def if_positive(criterion, then_, else_, out=None, queue=None):
contains *then_[i]* if *criterion[i]>0*, else *else_[i]*. contains *then_[i]* if *criterion[i]>0*, else *else_[i]*.
""" """
if all(isinstance(k, SCALAR_CLASSES) for k in [criterion,
then_,
else_]):
result = np.where(criterion, then_, else_)
if out is not None:
out[...] = result
return
return result
if not (criterion.shape == then_.shape == else_.shape): if not (criterion.shape == then_.shape == else_.shape):
raise ValueError("shapes do not match") raise ValueError("shapes do not match")
...@@ -2719,6 +2733,13 @@ def if_positive(criterion, then_, else_, out=None, queue=None): ...@@ -2719,6 +2733,13 @@ def if_positive(criterion, then_, else_, out=None, queue=None):
def maximum(a, b, out=None, queue=None): def maximum(a, b, out=None, queue=None):
"""Return the elementwise maximum of *a* and *b*.""" """Return the elementwise maximum of *a* and *b*."""
if all(isinstance(k, SCALAR_CLASSES) for k in [a, b]):
result = np.maximum(a, b)
if out is not None:
out[...] = result
return
return result
# silly, but functional # silly, but functional
return if_positive(a.mul_add(1, b, -1, queue=queue), a, b, return if_positive(a.mul_add(1, b, -1, queue=queue), a, b,
...@@ -2727,6 +2748,13 @@ def maximum(a, b, out=None, queue=None): ...@@ -2727,6 +2748,13 @@ def maximum(a, b, out=None, queue=None):
def minimum(a, b, out=None, queue=None): def minimum(a, b, out=None, queue=None):
"""Return the elementwise minimum of *a* and *b*.""" """Return the elementwise minimum of *a* and *b*."""
if all(isinstance(k, SCALAR_CLASSES) for k in [a, b]):
result = np.minimum(a, b)
if out is not None:
out[...] = result
return
return result
# silly, but functional # silly, but functional
return if_positive(a.mul_add(1, b, -1, queue=queue), b, a, return if_positive(a.mul_add(1, b, -1, queue=queue), b, a,
queue=queue, out=out) queue=queue, out=out)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment