diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index d5fca2c15c4dfc4fe448975ded1eeda2c77cded7..9d27eaae5f4bf3b86a8312c60789f2373ddde716 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -203,17 +203,27 @@ def randn(shape, dtype): rng = np.random.default_rng() dtype = np.dtype(dtype) + if shape == 0: + ashape = 1 + else: + ashape = shape + if dtype.kind == "c": dtype = np.dtype(f"<f{dtype.itemsize // 2}") - return rng.standard_normal(shape, dtype) \ - + 1j * rng.standard_normal(shape, dtype) + r = rng.standard_normal(ashape, dtype) \ + + 1j * rng.standard_normal(ashape, dtype) elif dtype.kind == "f": - return rng.standard_normal(shape, dtype) + r = rng.standard_normal(ashape, dtype) elif dtype.kind == "i": - return rng.integers(0, 128, shape, dtype) + r = rng.integers(0, 512, ashape, dtype) else: raise TypeError(dtype.kind) + if shape == 0: + return np.array(r[0]) + + return r + def assert_close_to_numpy(actx, op, args): assert np.allclose( @@ -672,11 +682,14 @@ class MyContainerDOFBcast: return self.mass.array_context -def _get_test_containers(actx, ambient_dim=2, size=50_000): - if size == 0: - x = DOFArray(actx, (actx.from_numpy(np.array(np.random.randn())),)) - else: - x = DOFArray(actx, (actx.from_numpy(np.random.randn(size)),)) +def _get_test_containers(actx, ambient_dim=2, shapes=50_000): + from numbers import Number + if isinstance(shapes, (Number, tuple)): + shapes = [shapes] + + x = DOFArray(actx, tuple([ + actx.from_numpy(randn(shape, np.float64)) + for shape in shapes])) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter dataclass_of_dofs = MyContainer( @@ -705,7 +718,7 @@ def _get_test_containers(actx, ambient_dim=2, size=50_000): def test_container_scalar_map(actx_factory): actx = actx_factory() - arys = _get_test_containers(actx, size=0) + arys = _get_test_containers(actx, shapes=0) arys += (np.pi,) from arraycontext import ( @@ -881,11 +894,12 @@ def test_container_norm(actx_factory, ord): def test_flatten_array_container(actx_factory): actx = actx_factory() - if not hasattr(actx.np, "astype"): - pytest.skip(f"'astype' not implemented on '{type(actx).__name__}'") from arraycontext import flatten, unflatten - arys = _get_test_containers(actx, size=512) + arys = _get_test_containers(actx, shapes=512) \ + + _get_test_containers(actx, shapes=(128, 67)) \ + + _get_test_containers(actx, shapes=[(64, 7), (154, 12)]) + for ary in arys: flat = flatten(ary, actx) assert flat.ndim == 1