From 340b78d89b363f7013db3c27edffc2bdc125af9a Mon Sep 17 00:00:00 2001 From: "Timothy A. Smith" <tasmith4@illinois.edu> Date: Mon, 27 May 2019 22:36:52 -0500 Subject: [PATCH] create new kernel_fixtures.py for direct interface with Loopy kernels --- fixtures.py | 10 ---------- kernel_fixtures.py | 15 +++++++++++++++ test.py | 6 ++++-- 3 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 kernel_fixtures.py diff --git a/fixtures.py b/fixtures.py index d11810c..10d0f45 100644 --- a/fixtures.py +++ b/fixtures.py @@ -92,16 +92,6 @@ def f_array(queue, *shape): ary = np.random.random_sample(shape).astype(np.float32).copy(order="F") return cl.array.to_device(queue, ary) -def mult_mat_vec(ctx_factory, alpha, a, b): - queue = get_queue(ctx_factory) - - c_dev = cl.array.empty(queue, 10, dtype=np.float32) - - prg = with_root_kernel(get_weno_program(), "mult_mat_vec") - prg(queue, a=a, b=b, c=c_dev, alpha=alpha) - - return c_dev.get() - def compute_flux_derivatives(ctx_factory, nvars, ndim, nx, ny, nz, states, fluxes, metrics, metric_jacobians): diff --git a/kernel_fixtures.py b/kernel_fixtures.py new file mode 100644 index 0000000..d480a5d --- /dev/null +++ b/kernel_fixtures.py @@ -0,0 +1,15 @@ +import numpy as np +import pyopencl as cl + +import fixtures + +def mult_mat_vec(ctx_factory, alpha, a, b): + queue = fixtures.get_queue(ctx_factory) + + c_dev = cl.array.empty(queue, 10, dtype=np.float32) + + prg = fixtures.with_root_kernel(fixtures.get_weno_program(), "mult_mat_vec") + prg(queue, a=a, b=b, c=c_dev, alpha=alpha) + + return c_dev.get() + diff --git a/test.py b/test.py index 7a8d406..555e953 100644 --- a/test.py +++ b/test.py @@ -17,17 +17,19 @@ from pyopencl.tools import ( # noqa import fixtures import comparison_fixtures as compare import setup_fixtures as setup +import kernel_fixtures as kernel + def test_matvec(ctx_factory): a = setup.random_array(10, 10) b = setup.random_array(10) - c = fixtures.mult_mat_vec(ctx_factory, a=a, b=b, alpha=1.0) + c = kernel.mult_mat_vec(ctx_factory, a=a, b=b, alpha=1.0) compare.arrays(a@b, c) -#@pytest.mark.skip("slow") +@pytest.mark.skip("slow") def test_compute_flux_derivatives(ctx_factory): logging.basicConfig(level="INFO") -- GitLab