From 464a0383f800ba0bd5f135136e28ae050f6c9f23 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 13 Jul 2021 14:12:03 -0500
Subject: [PATCH] extend some tests to use complex inputs

---
 test/test_arraycontext.py | 79 ++++++++++++++++++++++-----------------
 1 file changed, 44 insertions(+), 35 deletions(-)

diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 4642806..c9c26e2 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -188,6 +188,22 @@ def _thaw_dofarray(ary, actx):
 
 # {{{ assert_close_to_numpy*
 
+def randn(shape, dtype):
+    rng = np.random.default_rng()
+    dtype = np.dtype(dtype)
+
+    if dtype.kind == "c":
+        dtype = np.dtype(f"<f{dtype.itemsize // 2}")
+        return rng.standard_normal(shape, dtype) \
+            + 1j * rng.standard_normal(shape, dtype)
+    elif dtype.kind == "f":
+        return rng.standard_normal(shape, dtype)
+    elif dtype.kind == "i":
+        return rng.integers(0, 128, shape, dtype)
+    else:
+        raise TypeError(dtype.kind)
+
+
 def assert_close_to_numpy(actx, op, args):
     assert np.allclose(
             actx.to_numpy(
@@ -236,37 +252,44 @@ def assert_close_to_numpy_in_containers(actx, op, args):
 
 # {{{ np.function same as numpy
 
-@pytest.mark.parametrize(("sym_name", "n_args"), [
-            ("sin", 1),
-            ("exp", 1),
-            ("arctan2", 2),
-            ("minimum", 2),
-            ("maximum", 2),
-            ("where", 3),
-            ("conj", 1),
-            ("vdot", 2),
+@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [
+            ("sin", 1, np.float64),
+            ("sin", 1, np.complex128),
+            ("exp", 1, np.float64),
+            ("arctan2", 2, np.float64),
+            ("minimum", 2, np.float64),
+            ("maximum", 2, np.float64),
+            ("where", 3, np.float64),
+            ("conj", 1, np.float64),
+            ("conj", 1, np.complex128),
+            ("vdot", 2, np.float64),
+            ("vdot", 2, np.complex128),
+            ("abs", 1, np.float64),
+            ("abs", 1, np.complex128),
             ])
-def test_array_context_np_workalike(actx_factory, sym_name, n_args):
+def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype):
     actx = actx_factory()
     if not hasattr(actx.np, sym_name):
         pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'")
 
-    ndofs = 5000
-    args = [np.random.randn(ndofs) for i in range(n_args)]
+    ndofs = 512
+    args = [randn(ndofs, dtype) for i in range(n_args)]
+
     assert_close_to_numpy_in_containers(
             actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args)
 
 
-@pytest.mark.parametrize(("sym_name", "n_args"), [
-            # ("empty_like", 1),    # NOTE: fails np.allclose, obviously
-            ("zeros_like", 1),
-            ("ones_like", 1),
+@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [
+            ("zeros_like", 1, np.float64),
+            ("zeros_like", 1, np.complex128),
+            ("ones_like", 1, np.float64),
+            ("ones_like", 1, np.complex128),
             ])
-def test_array_context_np_like(actx_factory, sym_name, n_args):
+def test_array_context_np_like(actx_factory, sym_name, n_args, dtype):
     actx = actx_factory()
 
-    ndofs = 5000
-    args = [np.random.randn(ndofs) for i in range(n_args)]
+    ndofs = 512
+    args = [randn(ndofs, dtype) for i in range(n_args)]
     assert_close_to_numpy(
             actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args)
 
@@ -469,6 +492,7 @@ def test_dof_array_arithmetic_same_as_numpy(actx_factory):
 
 
 # {{{ reductions same as numpy
+
 @pytest.mark.parametrize("op", ["sum", "min", "max"])
 def test_dof_array_reductions_same_as_numpy(actx_factory, op):
     actx = actx_factory()
@@ -794,7 +818,7 @@ def test_numpy_conversion(actx_factory):
 @pytest.mark.parametrize("norm_ord", [2, np.inf])
 def test_norm_complex(actx_factory, norm_ord):
     actx = actx_factory()
-    a = np.random.randn(2000) + 1j * np.random.randn(2000)
+    a = randn(2000, np.complex128)
 
     norm_a_ref = np.linalg.norm(a, norm_ord)
     norm_a = actx.np.linalg.norm(actx.from_numpy(a), norm_ord)
@@ -883,21 +907,6 @@ def test_container_equality(actx_factory):
 # }}}
 
 
-# {{{ test_abs_complex
-
-def test_abs_complex(actx_factory):
-    actx = actx_factory()
-    a = np.random.randn(2000) + 1j * np.random.randn(2000)
-
-    abs_a_ref = np.abs(a)
-    abs_a = actx.np.abs(actx.from_numpy(a))
-
-    assert abs_a.dtype == abs_a_ref.dtype
-    np.testing.assert_allclose(actx.to_numpy(abs_a), abs_a_ref)
-
-# }}}
-
-
 # {{{ test_leaf_array_type_broadcasting
 
 @with_container_arithmetic(
-- 
GitLab