From c737ec9dd329b9fba82798670b0619464f0b9a5c Mon Sep 17 00:00:00 2001 From: Matthew Smith <mjsmith6@illinois.edu> Date: Thu, 14 Jul 2022 10:49:33 -0500 Subject: [PATCH] add failing tests for nan handling in min/max/minimum/maximum --- test/test_array.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/test/test_array.py b/test/test_array.py index 2a1ccfdf..1f9cdcfe 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1841,6 +1841,32 @@ def test_branch_operations_on_pure_scalars(): # }}} +# {{{ test_branch_operations_on_nans + +@pytest.mark.parametrize("op", [ + cl_array.maximum, + cl_array.minimum, +]) +def test_branch_operations_on_nans(ctx_factory, op): + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + x_np = np.array([np.nan, 1., np.nan, 2.], dtype=np.float64) + y_np = np.array([np.nan, np.nan, 1., 3.], dtype=np.float64) + + x_cl = cl_array.to_device(cq, x_np) + y_cl = cl_array.to_device(cq, y_np) + + ref = getattr(np, op.__name__)(x_np, y_np) + result = op(x_cl, y_cl) + if isinstance(result, cl_array.Array): + result = result.get() + + np.testing.assert_allclose(result, ref) + +# }}} + + # {{{ test_slice_copy def test_slice_copy(ctx_factory): @@ -2063,6 +2089,43 @@ def test_empty_reductions_vs_numpy(ctx_factory, reduction, supports_initial): # }}} +# {{{ test_reduction_nan_handling + +@pytest.mark.parametrize("with_initial", [False, True]) +@pytest.mark.parametrize("input_case", ["only nans", "mixed"]) +@pytest.mark.parametrize("reduction", [ + cl_array.sum, + cl_array.max, + cl_array.min, + ]) +def test_reduction_nan_handling(ctx_factory, reduction, input_case, with_initial): + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + if input_case == "only nans": + x_np = np.array([np.nan, np.nan], dtype=np.float64) + elif input_case == "mixed": + x_np = np.array([np.nan, 1.], dtype=np.float64) + else: + raise ValueError("invalid input case") + + x_cl = cl_array.to_device(cq, x_np) + + if with_initial: + ref = getattr(np, reduction.__name__)(x_np, initial=5.0) + result = reduction(x_cl, initial=5.0) + else: + ref = getattr(np, reduction.__name__)(x_np) + result = reduction(x_cl) + + if isinstance(result, cl_array.Array): + result = result.get() + + np.testing.assert_allclose(result, ref) + +# }}} + + # {{{ test_reductions_dtype def test_dtype_conversions(ctx_factory): -- GitLab