From 4b7ac4da133e47b2232e58eea52a7bef7802bfa6 Mon Sep 17 00:00:00 2001 From: "Timothy A. Smith" <tasmith4@illinois.edu> Date: Mon, 27 May 2019 23:46:45 -0500 Subject: [PATCH] refactor halo computations into params class --- kernel_fixtures.py | 12 ++---------- setup_fixtures.py | 39 ++++++++++++++++----------------------- 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/kernel_fixtures.py b/kernel_fixtures.py index 0640e8c..e1ff329 100644 --- a/kernel_fixtures.py +++ b/kernel_fixtures.py @@ -33,12 +33,8 @@ def compute_flux_derivatives(ctx_factory, params, arrays): prg = prg.with_kernel(cfd) - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim, - nx_halo, ny_halo, nz_halo), dtype=np.float32, order="F") + params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F") prg(queue, nvars=params.nvars, ndim=params.ndim, states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics, @@ -51,12 +47,8 @@ def compute_flux_derivatives_gpu(ctx_factory, params, arrays): queue = fixtures.get_queue(ctx_factory) - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim, - nx_halo, ny_halo, nz_halo), dtype=np.float32, order="F") + params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F") prg = prg.copy(target=lp.PyOpenCLTarget(queue.device)) diff --git a/setup_fixtures.py b/setup_fixtures.py index b576866..d3a236c 100644 --- a/setup_fixtures.py +++ b/setup_fixtures.py @@ -7,10 +7,15 @@ class FluxDerivativeParams: def __init__(self, nvars, ndim, nx, ny, nz): self.nvars = nvars self.ndim = ndim + self.nx = nx self.ny = ny self.nz = nz - self.nhalo = 6 + + self.nhalo = 3 + self.nx_halo = self.nx + 2*self.nhalo + self.ny_halo = self.ny + 2*self.nhalo + self.nz_halo = self.nz + 2*self.nhalo class FluxDerivativeArrays: @@ -25,33 +30,21 @@ def random_array(*shape): return np.random.random_sample(shape).astype(np.float32).copy(order="F") -def random_flux_derivative_arrays(params): - nvars = params.nvars - ndim = params.ndim - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - - states = random_array(nvars, nx_halo, ny_halo, nz_halo) - fluxes = random_array(nvars, ndim, nx_halo, ny_halo, nz_halo) - metrics = random_array(ndim, ndim, nx_halo, ny_halo, nz_halo) - metric_jacobians = random_array(nx_halo, ny_halo, nz_halo) +def random_flux_derivative_arrays(p): + states = random_array(p.nvars, p.nx_halo, p.ny_halo, p.nz_halo) + fluxes = random_array(p.nvars, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metrics = random_array(p.ndim, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metric_jacobians = random_array(p.nx_halo, p.ny_halo, p.nz_halo) return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians) -def random_flux_derivative_arrays_on_device(ctx_factory, params): +def random_flux_derivative_arrays_on_device(ctx_factory, p): queue = fixtures.get_queue(ctx_factory) - nvars = params.nvars - ndim = params.ndim - nx_halo = params.nx + params.nhalo - ny_halo = params.ny + params.nhalo - nz_halo = params.nz + params.nhalo - - states = fixtures.f_array(queue, nvars, nx_halo, ny_halo, nz_halo) - fluxes = fixtures.f_array(queue, nvars, ndim, nx_halo, ny_halo, nz_halo) - metrics = fixtures.f_array(queue, ndim, ndim, nx_halo, ny_halo, nz_halo) - metric_jacobians = fixtures.f_array(queue, nx_halo, ny_halo, nz_halo) + states = fixtures.f_array(queue, p.nvars, p.nx_halo, p.ny_halo, p.nz_halo) + fluxes = fixtures.f_array(queue, p.nvars, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metrics = fixtures.f_array(queue, p.ndim, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo) + metric_jacobians = fixtures.f_array(queue, p.nx_halo, p.ny_halo, p.nz_halo) return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians) -- GitLab