Skip to content
Snippets Groups Projects
Commit eece8668 authored by Timothy Smith's avatar Timothy Smith
Browse files

create a variable to hold output dimensions

parent a95a43aa
No related branches found
No related tags found
1 merge request!51RHS tests
...@@ -48,6 +48,7 @@ class GridResults: ...@@ -48,6 +48,7 @@ class GridResults:
ndim = 3 ndim = 3
nvars = 5 nvars = 5
dir_map = {"xi":1, "eta":2, "zeta":3} dir_map = {"xi":1, "eta":2, "zeta":3}
output_dims = (nvars, nx, ny, nz)
def __init__(self, dir_str): def __init__(self, dir_str):
self.direction = self.dir_map[dir_str] self.direction = self.dir_map[dir_str]
...@@ -56,7 +57,7 @@ class GridResults: ...@@ -56,7 +57,7 @@ class GridResults:
self.generalized_fluxes = np.full((self.nvars, self.nxhalo, self.nyhalo, self.nzhalo), 1.0, dtype=np.float64, order="F") self.generalized_fluxes = np.full((self.nvars, self.nxhalo, self.nyhalo, self.nzhalo), 1.0, dtype=np.float64, order="F")
self.metrics = np.full((self.ndim, self.ndim, self.nxhalo, self.nyhalo, self.nzhalo), 1.0, dtype=np.float64, order="F") self.metrics = np.full((self.ndim, self.ndim, self.nxhalo, self.nyhalo, self.nzhalo), 1.0, dtype=np.float64, order="F")
self.jacobians = np.full((self.nxhalo, self.nyhalo, self.nzhalo), 1.0, dtype=np.float64, order="F") self.jacobians = np.full((self.nxhalo, self.nyhalo, self.nzhalo), 1.0, dtype=np.float64, order="F")
self.generalized_flux_derivatives = np.full((self.nvars, self.nx, self.ny, self.nz), 0.0) self.generalized_flux_derivatives = np.full(self.output_dims, 0.0)
@pytest.fixture(scope="session", params=["xi", "eta", "zeta"]) @pytest.fixture(scope="session", params=["xi", "eta", "zeta"])
...@@ -70,7 +71,7 @@ def test_compute_flux_derivatives(queue, grid_results): ...@@ -70,7 +71,7 @@ def test_compute_flux_derivatives(queue, grid_results):
prg = u.get_weno_program() prg = u.get_weno_program()
prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
flux_derivatives_dev = u.empty_array_on_device(queue, data.nvars,data.nx,data.ny,data.nz) flux_derivatives_dev = u.empty_array_on_device(queue, *data.output_dims)
prg(queue, prg(queue,
nvars=data.nvars, nvars=data.nvars,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment