From a6f068a663ec45e52813aa8e478026e83fd543f3 Mon Sep 17 00:00:00 2001 From: "Timothy A. Smith" <tasmith4@illinois.edu> Date: Tue, 28 May 2019 23:38:46 -0500 Subject: [PATCH] add new fixture for setting up an empty array on the device --- kernel_fixtures.py | 7 ++++--- setup_fixtures.py | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/kernel_fixtures.py b/kernel_fixtures.py index 58d19bc..278d5f7 100644 --- a/kernel_fixtures.py +++ b/kernel_fixtures.py @@ -6,6 +6,7 @@ import loopy as lp # noqa import device_fixtures as device import program_fixtures as program import transform_fixtures as transform +import setup_fixtures as setup def with_root_kernel(prg, root_name): @@ -21,7 +22,7 @@ def with_root_kernel(prg, root_name): def mult_mat_vec(queue, prg, alpha, a, b): - c_dev = cl.array.empty(queue, *b.shape, dtype=np.float32) + c_dev = setup.empty_array_on_device(queue, b.shape) prg = with_root_kernel(prg, "mult_mat_vec") prg(queue, a=a, b=b, c=c_dev, alpha=alpha) @@ -30,8 +31,8 @@ def mult_mat_vec(queue, prg, alpha, a, b): def compute_flux_derivatives(queue, prg, params, arrays): - flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim, - params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F") + flux_derivatives_dev = setup.empty_array_on_device(queue, (params.nvars, params.ndim, + params.nx_halo, params.ny_halo, params.nz_halo)) prg(queue, nvars=params.nvars, ndim=params.ndim, states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics, diff --git a/setup_fixtures.py b/setup_fixtures.py index add7fea..255adfc 100644 --- a/setup_fixtures.py +++ b/setup_fixtures.py @@ -43,6 +43,10 @@ def flux_derivative_params(nvars, ndim, n): return FluxDerivativeParams(nvars, ndim, n, n, n) +def empty_array_on_device(queue, shape): + return cl.array.empty(queue, shape, dtype=np.float32, order="F") + + def random_array(*shape): return np.random.random_sample(shape).astype(np.float32).copy(order="F") -- GitLab