diff --git a/setup_fixtures.py b/setup_fixtures.py index 3b51fceb8bc8cfeace243262f57aba6122a48d66..9e24722a226ba9d7cdc82869dbbe39e5dcb568b8 100644 --- a/setup_fixtures.py +++ b/setup_fixtures.py @@ -8,6 +8,7 @@ class FluxDerivativeParams: self.nx = nx self.ny = ny self.nz = nz + self.nhalo = 6 class FluxDerivativeArrays: @@ -20,3 +21,18 @@ class FluxDerivativeArrays: def random_array(*shape): return np.random.random_sample(shape).astype(np.float32).copy(order="F") + + +def random_flux_derivative_arrays(params): + nvars = params.nvars + ndim = params.ndim + nx_halo = params.nx + params.nhalo + ny_halo = params.ny + params.nhalo + nz_halo = params.nz + params.nhalo + + states = random_array(nvars, nx_halo, ny_halo, nz_halo) + fluxes = random_array(nvars, ndim, nx_halo, ny_halo, nz_halo) + metrics = random_array(ndim, ndim, nx_halo, ny_halo, nz_halo) + metric_jacobians = random_array(nx_halo, ny_halo, nz_halo) + + return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians) diff --git a/test.py b/test.py index 999959baf28590499328ef9c50694150b4ad3d69..d99f84f030c10aa3ce2bbb6507dc3cd5a0c66504 100644 --- a/test.py +++ b/test.py @@ -31,6 +31,8 @@ def test_matvec(ctx_factory): def test_compute_flux_derivatives(ctx_factory): params = setup.FluxDerivativeParams(ndim=3, nvars=5, nx=10, ny=10, nz=10) + arrays = setup.random_flux_derivative_arrays(params) + ndim = 3 nvars = 5 nx = 10