From a6f068a663ec45e52813aa8e478026e83fd543f3 Mon Sep 17 00:00:00 2001
From: "Timothy A. Smith" <tasmith4@illinois.edu>
Date: Tue, 28 May 2019 23:38:46 -0500
Subject: [PATCH] add new fixture for setting up an empty array on the device

---
 kernel_fixtures.py | 7 ++++---
 setup_fixtures.py  | 4 ++++
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/kernel_fixtures.py b/kernel_fixtures.py
index 58d19bc..278d5f7 100644
--- a/kernel_fixtures.py
+++ b/kernel_fixtures.py
@@ -6,6 +6,7 @@ import loopy as lp  # noqa
 import device_fixtures as device
 import program_fixtures as program
 import transform_fixtures as transform
+import setup_fixtures as setup
 
 
 def with_root_kernel(prg, root_name):
@@ -21,7 +22,7 @@ def with_root_kernel(prg, root_name):
 
 
 def mult_mat_vec(queue, prg, alpha, a, b):
-    c_dev = cl.array.empty(queue, *b.shape, dtype=np.float32)
+    c_dev = setup.empty_array_on_device(queue, b.shape)
 
     prg = with_root_kernel(prg, "mult_mat_vec")
     prg(queue, a=a, b=b, c=c_dev, alpha=alpha)
@@ -30,8 +31,8 @@ def mult_mat_vec(queue, prg, alpha, a, b):
 
 
 def compute_flux_derivatives(queue, prg, params, arrays):
-    flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim,
-        params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F")
+    flux_derivatives_dev = setup.empty_array_on_device(queue, (params.nvars, params.ndim,
+        params.nx_halo, params.ny_halo, params.nz_halo))
 
     prg(queue, nvars=params.nvars, ndim=params.ndim,
             states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics,
diff --git a/setup_fixtures.py b/setup_fixtures.py
index add7fea..255adfc 100644
--- a/setup_fixtures.py
+++ b/setup_fixtures.py
@@ -43,6 +43,10 @@ def flux_derivative_params(nvars, ndim, n):
     return FluxDerivativeParams(nvars, ndim, n, n, n)
 
 
+def empty_array_on_device(queue, shape):
+    return cl.array.empty(queue, shape, dtype=np.float32, order="F")
+
+
 def random_array(*shape):
     return np.random.random_sample(shape).astype(np.float32).copy(order="F")
 
-- 
GitLab