Skip to content
Snippets Groups Projects
Commit 8aa2cf58 authored by Timothy A. Smith's avatar Timothy A. Smith
Browse files

pass queue, prg directly to kernel fixture for mult_mat_vec

parent 1edc017d
No related branches found
No related tags found
1 merge request!8Test refactoring
......@@ -5,12 +5,10 @@ import loopy as lp # noqa
import fixtures
def mult_mat_vec(ctx_factory, alpha, a, b):
queue = fixtures.get_queue(ctx_factory)
def mult_mat_vec(queue, prg, alpha, a, b):
c_dev = cl.array.empty(queue, 10, dtype=np.float32)
prg = fixtures.with_root_kernel(fixtures.get_weno_program(), "mult_mat_vec")
prg = fixtures.with_root_kernel(prg, "mult_mat_vec")
prg(queue, a=a, b=b, c=c_dev, alpha=alpha)
return c_dev.get()
......
......@@ -10,16 +10,22 @@ import comparison_fixtures as compare
import setup_fixtures as setup
import kernel_fixtures as kernel
import fixtures
import device_fixtures as device
def test_matvec(ctx_factory):
queue = device.get_queue(ctx_factory)
prg = fixtures.get_weno_program()
a = setup.random_array(10, 10)
b = setup.random_array(10)
c = kernel.mult_mat_vec(ctx_factory, a=a, b=b, alpha=1.0)
c = kernel.mult_mat_vec(queue, prg, alpha=1.0, a=a, b=b)
compare.arrays(a@b, c)
@pytest.mark.skip("slow")
def test_compute_flux_derivatives(ctx_factory):
params = setup.FluxDerivativeParams(ndim=3, nvars=5, nx=10, ny=10, nz=10)
arrays = setup.random_flux_derivative_arrays(params)
......@@ -27,6 +33,7 @@ def test_compute_flux_derivatives(ctx_factory):
kernel.compute_flux_derivatives(ctx_factory, params, arrays)
@pytest.mark.skip("slow")
def test_compute_flux_derivatives_gpu(ctx_factory):
params = setup.FluxDerivativeParams(ndim=3, nvars=5, nx=10, ny=10, nz=10)
arrays = setup.random_flux_derivative_arrays_on_device(ctx_factory, params)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment