From 464a0383f800ba0bd5f135136e28ae050f6c9f23 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Tue, 13 Jul 2021 14:12:03 -0500 Subject: [PATCH] extend some tests to use complex inputs --- test/test_arraycontext.py | 79 ++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 4642806..c9c26e2 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -188,6 +188,22 @@ def _thaw_dofarray(ary, actx): # {{{ assert_close_to_numpy* +def randn(shape, dtype): + rng = np.random.default_rng() + dtype = np.dtype(dtype) + + if dtype.kind == "c": + dtype = np.dtype(f"<f{dtype.itemsize // 2}") + return rng.standard_normal(shape, dtype) \ + + 1j * rng.standard_normal(shape, dtype) + elif dtype.kind == "f": + return rng.standard_normal(shape, dtype) + elif dtype.kind == "i": + return rng.integers(0, 128, shape, dtype) + else: + raise TypeError(dtype.kind) + + def assert_close_to_numpy(actx, op, args): assert np.allclose( actx.to_numpy( @@ -236,37 +252,44 @@ def assert_close_to_numpy_in_containers(actx, op, args): # {{{ np.function same as numpy -@pytest.mark.parametrize(("sym_name", "n_args"), [ - ("sin", 1), - ("exp", 1), - ("arctan2", 2), - ("minimum", 2), - ("maximum", 2), - ("where", 3), - ("conj", 1), - ("vdot", 2), +@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ + ("sin", 1, np.float64), + ("sin", 1, np.complex128), + ("exp", 1, np.float64), + ("arctan2", 2, np.float64), + ("minimum", 2, np.float64), + ("maximum", 2, np.float64), + ("where", 3, np.float64), + ("conj", 1, np.float64), + ("conj", 1, np.complex128), + ("vdot", 2, np.float64), + ("vdot", 2, np.complex128), + ("abs", 1, np.float64), + ("abs", 1, np.complex128), ]) -def test_array_context_np_workalike(actx_factory, sym_name, n_args): +def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): actx = actx_factory() if not hasattr(actx.np, sym_name): pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") - ndofs = 5000 - args = [np.random.randn(ndofs) for i in range(n_args)] + ndofs = 512 + args = [randn(ndofs, dtype) for i in range(n_args)] + assert_close_to_numpy_in_containers( actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args) -@pytest.mark.parametrize(("sym_name", "n_args"), [ - # ("empty_like", 1), # NOTE: fails np.allclose, obviously - ("zeros_like", 1), - ("ones_like", 1), +@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ + ("zeros_like", 1, np.float64), + ("zeros_like", 1, np.complex128), + ("ones_like", 1, np.float64), + ("ones_like", 1, np.complex128), ]) -def test_array_context_np_like(actx_factory, sym_name, n_args): +def test_array_context_np_like(actx_factory, sym_name, n_args, dtype): actx = actx_factory() - ndofs = 5000 - args = [np.random.randn(ndofs) for i in range(n_args)] + ndofs = 512 + args = [randn(ndofs, dtype) for i in range(n_args)] assert_close_to_numpy( actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args) @@ -469,6 +492,7 @@ def test_dof_array_arithmetic_same_as_numpy(actx_factory): # {{{ reductions same as numpy + @pytest.mark.parametrize("op", ["sum", "min", "max"]) def test_dof_array_reductions_same_as_numpy(actx_factory, op): actx = actx_factory() @@ -794,7 +818,7 @@ def test_numpy_conversion(actx_factory): @pytest.mark.parametrize("norm_ord", [2, np.inf]) def test_norm_complex(actx_factory, norm_ord): actx = actx_factory() - a = np.random.randn(2000) + 1j * np.random.randn(2000) + a = randn(2000, np.complex128) norm_a_ref = np.linalg.norm(a, norm_ord) norm_a = actx.np.linalg.norm(actx.from_numpy(a), norm_ord) @@ -883,21 +907,6 @@ def test_container_equality(actx_factory): # }}} -# {{{ test_abs_complex - -def test_abs_complex(actx_factory): - actx = actx_factory() - a = np.random.randn(2000) + 1j * np.random.randn(2000) - - abs_a_ref = np.abs(a) - abs_a = actx.np.abs(actx.from_numpy(a)) - - assert abs_a.dtype == abs_a_ref.dtype - np.testing.assert_allclose(actx.to_numpy(abs_a), abs_a_ref) - -# }}} - - # {{{ test_leaf_array_type_broadcasting @with_container_arithmetic( -- GitLab