diff --git a/pyopencl/array.py b/pyopencl/array.py index b8fde35ed0cc91d5161307b76ca011cc5cb5fa1e..cb66121a5d632a44ab0b78c00fc8a2d19a6f4502 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -2942,14 +2942,30 @@ def maximum(a, b, out=None, queue=None): return result # silly, but functional - if isinstance(a, Array): + a_is_array = isinstance(a, Array) + b_is_array = isinstance(b, Array) + if a_is_array: criterion = a.mul_add(1, b, -1, queue=queue) - elif isinstance(b, Array): + elif b_is_array: criterion = b.mul_add(-1, a, 1, queue=queue) else: raise AssertionError - return if_positive(criterion, a, b, queue=queue, out=out) + # {{{ propagate NaNs + + if a_is_array: + a_with_nan = a.mul_add(1, b, 0, queue=queue) + else: + a_with_nan = b.mul_add(0, a, 1, queue=queue) + + if b_is_array: + b_with_nan = b.mul_add(1, a, 0, queue=queue) + else: + b_with_nan = a.mul_add(0, b, 1, queue=queue) + + # }}} + + return if_positive(criterion, a_with_nan, b_with_nan, queue=queue, out=out) def minimum(a, b, out=None, queue=None): @@ -2963,14 +2979,30 @@ def minimum(a, b, out=None, queue=None): return result # silly, but functional - if isinstance(a, Array): + a_is_array = isinstance(a, Array) + b_is_array = isinstance(b, Array) + if a_is_array: criterion = a.mul_add(1, b, -1, queue=queue) - elif isinstance(b, Array): + elif b_is_array: criterion = b.mul_add(-1, a, 1, queue=queue) else: raise AssertionError - return if_positive(criterion, b, a, queue=queue, out=out) + # {{{ propagate NaNs + + if a_is_array: + a_with_nan = a.mul_add(1, b, 0, queue=queue) + else: + a_with_nan = b.mul_add(0, a, 1, queue=queue) + + if b_is_array: + b_with_nan = b.mul_add(1, a, 0, queue=queue) + else: + b_with_nan = a.mul_add(0, b, 1, queue=queue) + + # }}} + + return if_positive(criterion, b_with_nan, a_with_nan, queue=queue, out=out) # }}} diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py index f8853feb7ec3d5fc1f5a6e692f71156bd0786a1c..92bf51eb1cd1ce9fd7a858e08fe2cc3374fe589d 100644 --- a/pyopencl/reduction.py +++ b/pyopencl/reduction.py @@ -649,7 +649,7 @@ def get_minmax_neutral(what, dtype): @context_dependent_memoize def get_minmax_kernel(ctx, what, dtype): if dtype.kind == "f": - reduce_expr = "f%s(a,b)" % what + reduce_expr = "f%s_nanprop(a,b)" % what elif dtype.kind in "iu": reduce_expr = "%s(a,b)" % what else: @@ -660,7 +660,11 @@ def get_minmax_kernel(ctx, what, dtype): reduce_expr=f"{reduce_expr}", arguments="const {tp} *in".format( tp=dtype_to_ctype(dtype), - ), preamble="#define MY_INFINITY (1./0)") + ), preamble=""" + #define MY_INFINITY (1./0) + #define fmin_nanprop(a, b) (isnan(a) || isnan(b)) ? a+b : fmin(a, b) + #define fmax_nanprop(a, b) (isnan(a) || isnan(b)) ? a+b : fmax(a, b) + """) @context_dependent_memoize