diff --git a/test.py b/test.py index 6b42fdd3a6e9f47fd5d4d4d50f413fdd4a5eb2f1..3db55b64b1c0a52b6b9550862a4ab850bc455b3f 100644 --- a/test.py +++ b/test.py @@ -16,12 +16,43 @@ from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) -import device_fixtures as device -import program_fixtures as program import setup_fixtures as setup import kernel_fixtures as kernel +_QUEUE = [] + + +def get_queue(ctx_factory): + if not _QUEUE: + setup_queue(ctx_factory) + return _QUEUE[0] + + +def setup_queue(ctx_factory): + ctx = ctx_factory() + _QUEUE.append(cl.CommandQueue(ctx)) + + +_WENO_PRG = [] + + +def parse_weno(): + fn = "WENO.F90" + + with open(fn, "r") as infile: + infile_content = infile.read() + + prg = lp.parse_transformed_fortran(infile_content, filename=fn) + _WENO_PRG.append(prg) + + +def get_weno_program(): + if not _WENO_PRG: + parse_weno() + return _WENO_PRG[0] + + def compare_arrays(a, b): assert a == approx(b) @@ -117,8 +148,8 @@ def transform_compute_flux_derivative_gpu(queue, prg): ("2 1,4 1,8 2,12 3,64 11", "12 3,24 3,48 6,75.2 10.6,403.2 37.8", "z") ]) def test_roe_uniform_grid(ctx_factory, states_str, fluxes_str, direction): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() + queue = get_queue(ctx_factory) + prg = get_weno_program() params = setup.roe_params(nvars=5, ndim=3, direction=direction) states = setup.array_from_string(states_str) @@ -132,8 +163,8 @@ def test_roe_uniform_grid(ctx_factory, states_str, fluxes_str, direction): def test_matvec(ctx_factory): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() + queue = get_queue(ctx_factory) + prg = get_weno_program() a = setup.random_array(10, 10) b = setup.random_array(10) @@ -145,8 +176,8 @@ def test_matvec(ctx_factory): #@pytest.mark.slow def test_compute_flux_derivatives(ctx_factory): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() + queue = get_queue(ctx_factory) + prg = get_weno_program() prg = transform_compute_flux_derivative_basic(prg) params = setup.flux_derivative_params(ndim=3, nvars=5, n=10) @@ -157,8 +188,8 @@ def test_compute_flux_derivatives(ctx_factory): #@pytest.mark.slow def test_compute_flux_derivatives_gpu(ctx_factory): - queue = device.get_queue(ctx_factory) - prg = program.get_weno() + queue = get_queue(ctx_factory) + prg = get_weno_program() prg = transform_compute_flux_derivative_gpu(queue, prg) params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)