diff --git a/kernel_fixtures.py b/kernel_fixtures.py index 0640e8ce493a7cad6cf860e621aa43a5e831cd19..e1ff32968372902be6f60a69286b1955f48fe7c9 100644 --- a/kernel_fixtures.py +++ b/kernel_fixtures.py @@ -33,12 +33,8 @@ def compute_flux_derivatives(ctx_factory, params, arrays): prg = prg.with_kernel(cfd) - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim, - nx_halo, ny_halo, nz_halo), dtype=np.float32, order="F") + params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F") prg(queue, nvars=params.nvars, ndim=params.ndim, states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics, @@ -51,12 +47,8 @@ def compute_flux_derivatives_gpu(ctx_factory, params, arrays): queue = fixtures.get_queue(ctx_factory) - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim, - nx_halo, ny_halo, nz_halo), dtype=np.float32, order="F") + params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F") prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) diff --git a/setup_fixtures.py b/setup_fixtures.py index b576866e758d40ae9a633cbb9f9cdb281dda8212..d3a236c07f2ea0df5193b70a122f09ae0923597c 100644 --- a/setup_fixtures.py +++ b/setup_fixtures.py @@ -7,10 +7,15 @@ class FluxDerivativeParams: def __init__(self, nvars, ndim, nx, ny, nz): self.nvars = nvars self.ndim = ndim + self.nx = nx self.ny = ny self.nz = nz - self.nhalo = 6 + + self.nhalo = 3 + self.nx_halo = self.nx + 2*self.nhalo + self.ny_halo = self.ny + 2*self.nhalo + self.nz_halo = self.nz + 2*self.nhalo class FluxDerivativeArrays: @@ -25,33 +30,21 @@ def random_array(*shape): return np.random.random_sample(shape).astype(np.float32).copy(order="F") -def random_flux_derivative_arrays(params): - nvars = params.nvars - ndim = params.ndim - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - - states = random_array(nvars, nx_halo, ny_halo, nz_halo) - fluxes = random_array(nvars, ndim, nx_halo, ny_halo, nz_halo) - metrics = random_array(ndim, ndim, nx_halo, ny_halo, nz_halo) - metric_jacobians = random_array(nx_halo, ny_halo, nz_halo) +def random_flux_derivative_arrays(p): + states = random_array(p.nvars, p.nx_halo, p.ny_halo, p.nz_halo) + fluxes = random_array(p.nvars, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metrics = random_array(p.ndim, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metric_jacobians = random_array(p.nx_halo, p.ny_halo, p.nz_halo) return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians) -def random_flux_derivative_arrays_on_device(ctx_factory, params): +def random_flux_derivative_arrays_on_device(ctx_factory, p): queue = fixtures.get_queue(ctx_factory) - nvars = params.nvars - ndim = params.ndim - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - - states = fixtures.f_array(queue, nvars, nx_halo, ny_halo, nz_halo) - fluxes = fixtures.f_array(queue, nvars, ndim, nx_halo, ny_halo, nz_halo) - metrics = fixtures.f_array(queue, ndim, ndim, nx_halo, ny_halo, nz_halo) - metric_jacobians = fixtures.f_array(queue, nx_halo, ny_halo, nz_halo) + states = fixtures.f_array(queue, p.nvars, p.nx_halo, p.ny_halo, p.nz_halo) + fluxes = fixtures.f_array(queue, p.nvars, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metrics = fixtures.f_array(queue, p.ndim, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metric_jacobians = fixtures.f_array(queue, p.nx_halo, p.ny_halo, p.nz_halo) return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians)