From c737ec9dd329b9fba82798670b0619464f0b9a5c Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Thu, 14 Jul 2022 10:49:33 -0500
Subject: [PATCH] add failing tests for nan handling in min/max/minimum/maximum

---
 test/test_array.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/test/test_array.py b/test/test_array.py
index 2a1ccfdf..1f9cdcfe 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -1841,6 +1841,32 @@ def test_branch_operations_on_pure_scalars():
 # }}}
 
 
+# {{{ test_branch_operations_on_nans
+
+@pytest.mark.parametrize("op", [
+    cl_array.maximum,
+    cl_array.minimum,
+])
+def test_branch_operations_on_nans(ctx_factory, op):
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    x_np = np.array([np.nan, 1., np.nan, 2.], dtype=np.float64)
+    y_np = np.array([np.nan, np.nan, 1., 3.], dtype=np.float64)
+
+    x_cl = cl_array.to_device(cq, x_np)
+    y_cl = cl_array.to_device(cq, y_np)
+
+    ref = getattr(np, op.__name__)(x_np, y_np)
+    result = op(x_cl, y_cl)
+    if isinstance(result, cl_array.Array):
+        result = result.get()
+
+    np.testing.assert_allclose(result, ref)
+
+# }}}
+
+
 # {{{ test_slice_copy
 
 def test_slice_copy(ctx_factory):
@@ -2063,6 +2089,43 @@ def test_empty_reductions_vs_numpy(ctx_factory, reduction, supports_initial):
 # }}}
 
 
+# {{{ test_reduction_nan_handling
+
+@pytest.mark.parametrize("with_initial", [False, True])
+@pytest.mark.parametrize("input_case", ["only nans", "mixed"])
+@pytest.mark.parametrize("reduction", [
+    cl_array.sum,
+    cl_array.max,
+    cl_array.min,
+    ])
+def test_reduction_nan_handling(ctx_factory, reduction, input_case, with_initial):
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    if input_case == "only nans":
+        x_np = np.array([np.nan, np.nan], dtype=np.float64)
+    elif input_case == "mixed":
+        x_np = np.array([np.nan, 1.], dtype=np.float64)
+    else:
+        raise ValueError("invalid input case")
+
+    x_cl = cl_array.to_device(cq, x_np)
+
+    if with_initial:
+        ref = getattr(np, reduction.__name__)(x_np, initial=5.0)
+        result = reduction(x_cl, initial=5.0)
+    else:
+        ref = getattr(np, reduction.__name__)(x_np)
+        result = reduction(x_cl)
+
+    if isinstance(result, cl_array.Array):
+        result = result.get()
+
+    np.testing.assert_allclose(result, ref)
+
+# }}}
+
+
 # {{{ test_reductions_dtype
 
 def test_dtype_conversions(ctx_factory):
-- 
GitLab