Skip to content
Snippets Groups Projects
Commit 340b78d8 authored by Timothy A. Smith's avatar Timothy A. Smith
Browse files

create new kernel_fixtures.py for direct interface with Loopy kernels

parent 578dc7c4
No related branches found
No related tags found
1 merge request!8Test refactoring
......@@ -92,16 +92,6 @@ def f_array(queue, *shape):
ary = np.random.random_sample(shape).astype(np.float32).copy(order="F")
return cl.array.to_device(queue, ary)
def mult_mat_vec(ctx_factory, alpha, a, b):
queue = get_queue(ctx_factory)
c_dev = cl.array.empty(queue, 10, dtype=np.float32)
prg = with_root_kernel(get_weno_program(), "mult_mat_vec")
prg(queue, a=a, b=b, c=c_dev, alpha=alpha)
return c_dev.get()
def compute_flux_derivatives(ctx_factory,
nvars, ndim, nx, ny, nz,
states, fluxes, metrics, metric_jacobians):
......
import numpy as np
import pyopencl as cl
import fixtures
def mult_mat_vec(ctx_factory, alpha, a, b):
queue = fixtures.get_queue(ctx_factory)
c_dev = cl.array.empty(queue, 10, dtype=np.float32)
prg = fixtures.with_root_kernel(fixtures.get_weno_program(), "mult_mat_vec")
prg(queue, a=a, b=b, c=c_dev, alpha=alpha)
return c_dev.get()
......@@ -17,17 +17,19 @@ from pyopencl.tools import ( # noqa
import fixtures
import comparison_fixtures as compare
import setup_fixtures as setup
import kernel_fixtures as kernel
def test_matvec(ctx_factory):
a = setup.random_array(10, 10)
b = setup.random_array(10)
c = fixtures.mult_mat_vec(ctx_factory, a=a, b=b, alpha=1.0)
c = kernel.mult_mat_vec(ctx_factory, a=a, b=b, alpha=1.0)
compare.arrays(a@b, c)
#@pytest.mark.skip("slow")
@pytest.mark.skip("slow")
def test_compute_flux_derivatives(ctx_factory):
logging.basicConfig(level="INFO")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment