Skip to content
Snippets Groups Projects
Commit c737ec9d authored by Matt Smith's avatar Matt Smith Committed by Andreas Klöckner
Browse files

add failing tests for nan handling in min/max/minimum/maximum

parent 77967e43
No related branches found
No related tags found
No related merge requests found
...@@ -1841,6 +1841,32 @@ def test_branch_operations_on_pure_scalars(): ...@@ -1841,6 +1841,32 @@ def test_branch_operations_on_pure_scalars():
# }}} # }}}
# {{{ test_branch_operations_on_nans
@pytest.mark.parametrize("op", [
cl_array.maximum,
cl_array.minimum,
])
def test_branch_operations_on_nans(ctx_factory, op):
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)
x_np = np.array([np.nan, 1., np.nan, 2.], dtype=np.float64)
y_np = np.array([np.nan, np.nan, 1., 3.], dtype=np.float64)
x_cl = cl_array.to_device(cq, x_np)
y_cl = cl_array.to_device(cq, y_np)
ref = getattr(np, op.__name__)(x_np, y_np)
result = op(x_cl, y_cl)
if isinstance(result, cl_array.Array):
result = result.get()
np.testing.assert_allclose(result, ref)
# }}}
# {{{ test_slice_copy # {{{ test_slice_copy
def test_slice_copy(ctx_factory): def test_slice_copy(ctx_factory):
...@@ -2063,6 +2089,43 @@ def test_empty_reductions_vs_numpy(ctx_factory, reduction, supports_initial): ...@@ -2063,6 +2089,43 @@ def test_empty_reductions_vs_numpy(ctx_factory, reduction, supports_initial):
# }}} # }}}
# {{{ test_reduction_nan_handling
@pytest.mark.parametrize("with_initial", [False, True])
@pytest.mark.parametrize("input_case", ["only nans", "mixed"])
@pytest.mark.parametrize("reduction", [
cl_array.sum,
cl_array.max,
cl_array.min,
])
def test_reduction_nan_handling(ctx_factory, reduction, input_case, with_initial):
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)
if input_case == "only nans":
x_np = np.array([np.nan, np.nan], dtype=np.float64)
elif input_case == "mixed":
x_np = np.array([np.nan, 1.], dtype=np.float64)
else:
raise ValueError("invalid input case")
x_cl = cl_array.to_device(cq, x_np)
if with_initial:
ref = getattr(np, reduction.__name__)(x_np, initial=5.0)
result = reduction(x_cl, initial=5.0)
else:
ref = getattr(np, reduction.__name__)(x_np)
result = reduction(x_cl)
if isinstance(result, cl_array.Array):
result = result.get()
np.testing.assert_allclose(result, ref)
# }}}
# {{{ test_reductions_dtype # {{{ test_reductions_dtype
def test_dtype_conversions(ctx_factory): def test_dtype_conversions(ctx_factory):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment