From 4b7ac4da133e47b2232e58eea52a7bef7802bfa6 Mon Sep 17 00:00:00 2001
From: "Timothy A. Smith" <tasmith4@illinois.edu>
Date: Mon, 27 May 2019 23:46:45 -0500
Subject: [PATCH] refactor halo computations into params class

---
 kernel_fixtures.py | 12 ++----------
 setup_fixtures.py  | 39 ++++++++++++++++-----------------------
 2 files changed, 18 insertions(+), 33 deletions(-)

diff --git a/kernel_fixtures.py b/kernel_fixtures.py
index 0640e8c..e1ff329 100644
--- a/kernel_fixtures.py
+++ b/kernel_fixtures.py
@@ -33,12 +33,8 @@ def compute_flux_derivatives(ctx_factory, params, arrays):
 
     prg = prg.with_kernel(cfd)
 
-    nx_halo = params.nx + params.nhalo
-    ny_halo = params.ny + params.nhalo
-    nz_halo = params.nz + params.nhalo
-
     flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim,
-        nx_halo, ny_halo, nz_halo), dtype=np.float32, order="F")
+        params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F")
 
     prg(queue, nvars=params.nvars, ndim=params.ndim,
             states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics,
@@ -51,12 +47,8 @@ def compute_flux_derivatives_gpu(ctx_factory, params, arrays):
 
     queue = fixtures.get_queue(ctx_factory)
 
-    nx_halo = params.nx + params.nhalo
-    ny_halo = params.ny + params.nhalo
-    nz_halo = params.nz + params.nhalo
-
     flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim,
-        nx_halo, ny_halo, nz_halo), dtype=np.float32, order="F")
+        params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F")
 
     prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
 
diff --git a/setup_fixtures.py b/setup_fixtures.py
index b576866..d3a236c 100644
--- a/setup_fixtures.py
+++ b/setup_fixtures.py
@@ -7,10 +7,15 @@ class FluxDerivativeParams:
     def __init__(self, nvars, ndim, nx, ny, nz):
         self.nvars = nvars
         self.ndim = ndim
+
         self.nx = nx
         self.ny = ny
         self.nz = nz
-        self.nhalo = 6
+
+        self.nhalo = 3
+        self.nx_halo = self.nx + 2*self.nhalo
+        self.ny_halo = self.ny + 2*self.nhalo
+        self.nz_halo = self.nz + 2*self.nhalo
 
 
 class FluxDerivativeArrays:
@@ -25,33 +30,21 @@ def random_array(*shape):
     return np.random.random_sample(shape).astype(np.float32).copy(order="F")
 
 
-def random_flux_derivative_arrays(params):
-    nvars = params.nvars
-    ndim = params.ndim
-    nx_halo = params.nx + params.nhalo
-    ny_halo = params.ny + params.nhalo
-    nz_halo = params.nz + params.nhalo
-
-    states = random_array(nvars, nx_halo, ny_halo, nz_halo)
-    fluxes = random_array(nvars, ndim, nx_halo, ny_halo, nz_halo)
-    metrics = random_array(ndim, ndim, nx_halo, ny_halo, nz_halo)
-    metric_jacobians = random_array(nx_halo, ny_halo, nz_halo)
+def random_flux_derivative_arrays(p):
+    states = random_array(p.nvars, p.nx_halo, p.ny_halo, p.nz_halo)
+    fluxes = random_array(p.nvars, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo)
+    metrics = random_array(p.ndim, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo)
+    metric_jacobians = random_array(p.nx_halo, p.ny_halo, p.nz_halo)
 
     return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians)
 
 
-def random_flux_derivative_arrays_on_device(ctx_factory, params):
+def random_flux_derivative_arrays_on_device(ctx_factory, p):
     queue = fixtures.get_queue(ctx_factory)
 
-    nvars = params.nvars
-    ndim = params.ndim
-    nx_halo = params.nx + params.nhalo
-    ny_halo = params.ny + params.nhalo
-    nz_halo = params.nz + params.nhalo
-
-    states = fixtures.f_array(queue, nvars, nx_halo, ny_halo, nz_halo)
-    fluxes = fixtures.f_array(queue, nvars, ndim, nx_halo, ny_halo, nz_halo)
-    metrics = fixtures.f_array(queue, ndim, ndim, nx_halo, ny_halo, nz_halo)
-    metric_jacobians = fixtures.f_array(queue, nx_halo, ny_halo, nz_halo)
+    states = fixtures.f_array(queue, p.nvars, p.nx_halo, p.ny_halo, p.nz_halo)
+    fluxes = fixtures.f_array(queue, p.nvars, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo)
+    metrics = fixtures.f_array(queue, p.ndim, p.ndim, p.nx_halo, p.ny_halo, p.nz_halo)
+    metric_jacobians = fixtures.f_array(queue, p.nx_halo, p.ny_halo, p.nz_halo)
 
     return FluxDerivativeArrays(states, fluxes, metrics, metric_jacobians)
-- 
GitLab