From 23ff0e556d43f8c154ffaa36feca2431ca25efe5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 25 Jul 2022 10:25:47 -0400 Subject: [PATCH] Fix NaN propagation in cla.m{in,ax}{,imum} --- pyopencl/array.py | 44 +++++++++++++++++++++++++++++++++++++------ pyopencl/reduction.py | 8 ++++++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index b8fde35e..cb66121a 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -2942,14 +2942,30 @@ def maximum(a, b, out=None, queue=None): return result # silly, but functional - if isinstance(a, Array): + a_is_array = isinstance(a, Array) + b_is_array = isinstance(b, Array) + if a_is_array: criterion = a.mul_add(1, b, -1, queue=queue) - elif isinstance(b, Array): + elif b_is_array: criterion = b.mul_add(-1, a, 1, queue=queue) else: raise AssertionError - return if_positive(criterion, a, b, queue=queue, out=out) + # {{{ propagate NaNs + + if a_is_array: + a_with_nan = a.mul_add(1, b, 0, queue=queue) + else: + a_with_nan = b.mul_add(0, a, 1, queue=queue) + + if b_is_array: + b_with_nan = b.mul_add(1, a, 0, queue=queue) + else: + b_with_nan = a.mul_add(0, b, 1, queue=queue) + + # }}} + + return if_positive(criterion, a_with_nan, b_with_nan, queue=queue, out=out) def minimum(a, b, out=None, queue=None): @@ -2963,14 +2979,30 @@ def minimum(a, b, out=None, queue=None): return result # silly, but functional - if isinstance(a, Array): + a_is_array = isinstance(a, Array) + b_is_array = isinstance(b, Array) + if a_is_array: criterion = a.mul_add(1, b, -1, queue=queue) - elif isinstance(b, Array): + elif b_is_array: criterion = b.mul_add(-1, a, 1, queue=queue) else: raise AssertionError - return if_positive(criterion, b, a, queue=queue, out=out) + # {{{ propagate NaNs + + if a_is_array: + a_with_nan = a.mul_add(1, b, 0, queue=queue) + else: + a_with_nan = b.mul_add(0, a, 1, queue=queue) + + if b_is_array: + b_with_nan = b.mul_add(1, a, 0, queue=queue) + else: + b_with_nan = a.mul_add(0, b, 1, queue=queue) + + # }}} + + return if_positive(criterion, b_with_nan, a_with_nan, queue=queue, out=out) # }}} diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py index f8853feb..92bf51eb 100644 --- a/pyopencl/reduction.py +++ b/pyopencl/reduction.py @@ -649,7 +649,7 @@ def get_minmax_neutral(what, dtype): @context_dependent_memoize def get_minmax_kernel(ctx, what, dtype): if dtype.kind == "f": - reduce_expr = "f%s(a,b)" % what + reduce_expr = "f%s_nanprop(a,b)" % what elif dtype.kind in "iu": reduce_expr = "%s(a,b)" % what else: @@ -660,7 +660,11 @@ def get_minmax_kernel(ctx, what, dtype): reduce_expr=f"{reduce_expr}", arguments="const {tp} *in".format( tp=dtype_to_ctype(dtype), - ), preamble="#define MY_INFINITY (1./0)") + ), preamble=""" + #define MY_INFINITY (1./0) + #define fmin_nanprop(a, b) (isnan(a) || isnan(b)) ? a+b : fmin(a, b) + #define fmax_nanprop(a, b) (isnan(a) || isnan(b)) ? a+b : fmax(a, b) + """) @context_dependent_memoize -- GitLab