diff --git a/kernel_fixtures.py b/kernel_fixtures.py index 58d19bcd849229ef350ec0d72298b0b2ca627233..278d5f7d4f69a565f5083a49a23d3dfa00c5701b 100644 --- a/kernel_fixtures.py +++ b/kernel_fixtures.py @@ -6,6 +6,7 @@ import loopy as lp # noqa import device_fixtures as device import program_fixtures as program import transform_fixtures as transform +import setup_fixtures as setup def with_root_kernel(prg, root_name): @@ -21,7 +22,7 @@ def with_root_kernel(prg, root_name): def mult_mat_vec(queue, prg, alpha, a, b): - c_dev = cl.array.empty(queue, *b.shape, dtype=np.float32) + c_dev = setup.empty_array_on_device(queue, b.shape) prg = with_root_kernel(prg, "mult_mat_vec") prg(queue, a=a, b=b, c=c_dev, alpha=alpha) @@ -30,8 +31,8 @@ def mult_mat_vec(queue, prg, alpha, a, b): def compute_flux_derivatives(queue, prg, params, arrays): - flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim, - params.nx_halo, params.ny_halo, params.nz_halo), dtype=np.float32, order="F") + flux_derivatives_dev = setup.empty_array_on_device(queue, (params.nvars, params.ndim, + params.nx_halo, params.ny_halo, params.nz_halo)) prg(queue, nvars=params.nvars, ndim=params.ndim, states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics, diff --git a/setup_fixtures.py b/setup_fixtures.py index add7feaa199359ae018321b60eb150bd8fdf9365..255adfcac99e21cd5e89037a020009f9c6d278bf 100644 --- a/setup_fixtures.py +++ b/setup_fixtures.py @@ -43,6 +43,10 @@ def flux_derivative_params(nvars, ndim, n): return FluxDerivativeParams(nvars, ndim, n, n, n) +def empty_array_on_device(queue, shape): + return cl.array.empty(queue, shape, dtype=np.float32, order="F") + + def random_array(*shape): return np.random.random_sample(shape).astype(np.float32).copy(order="F")