# Tests for the WENO program: Roe eigensystem, mat-vec product, flux derivatives.
import logging

from pyopencl.tools import (  # noqa
        pytest_generate_tests_for_pyopencl
        as pytest_generate_tests)

import device_fixtures as device
import program_fixtures as program
import transform_fixtures as transform
import setup_fixtures as setup
# NOTE(review): every test uses ``kernel.*`` but it was never imported; the
# module name is inferred from the ``*_fixtures`` naming pattern -- confirm.
import kernel_fixtures as kernel
import comparison_fixtures as compare


def test_roe(ctx_factory):
    """Check the Roe eigensystem for one fixed state in the x direction.

    Builds ``R``, ``Rinv`` and ``lam`` with ``kernel.roe_eigensystem`` using
    identity frozen metrics. It then validates them with
    ``compare.roe_identity`` (presumably that ``R`` and ``Rinv`` are
    inverses) and with ``compare.roe_property`` against a fixed set of
    fluxes.
    """
    queue = device.get_queue(ctx_factory)
    prg = program.get_weno()

    params = setup.roe_params(nvars=5, ndim=3, direction="x")
    # Presumably one comma-separated row per variable, with space-separated
    # columns -- confirm against setup.array_from_string.
    states = setup.array_from_string("2 1,4 1,4 1,4 1,20 5.5")
    metrics_frozen = setup.identity(params.ndim)
    R, Rinv, lam = kernel.roe_eigensystem(
            queue, prg, params, states, metrics_frozen)

    compare.roe_identity(states, R, Rinv)

    fluxes = setup.array_from_string("4 1,11.2 2.6,8 1,8 1,46.4 7.1")
    compare.roe_property(states, fluxes, R, Rinv, lam)


def test_matvec(ctx_factory):
    """Run the WENO program's matrix-vector kernel on random inputs.

    NOTE(review): the value returned by ``kernel.mult_mat_vec`` is never
    checked against a reference, so this is currently only a smoke test.
    """
    cl_queue = device.get_queue(ctx_factory)
    weno_prg = program.get_weno()

    matrix = setup.random_array(10, 10)
    vector = setup.random_array(10)

    # TODO: compare ``product`` with a host-side reference (presumably
    # alpha * a @ b -- confirm the kernel's contract first).
    product = kernel.mult_mat_vec(
            cl_queue, weno_prg, alpha=1.0, a=matrix, b=vector)


def test_compute_flux_derivatives(ctx_factory):
    """Smoke-test the flux-derivative kernel with the basic transform.

    Applies ``transform.compute_flux_derivative_basic`` to the WENO program.
    It then runs ``kernel.compute_flux_derivatives`` on arrays from
    ``setup.random_flux_derivative_arrays``. The output is not compared
    against a reference.
    """
    queue = device.get_queue(ctx_factory)
    prg = program.get_weno()
    prg = transform.compute_flux_derivative_basic(prg)

    params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
    arrays = setup.random_flux_derivative_arrays(params)
    kernel.compute_flux_derivatives(queue, prg, params, arrays)


def test_compute_flux_derivatives_gpu(ctx_factory):
    """Smoke-test the flux-derivative kernel with the GPU transform.

    This test mirrors ``test_compute_flux_derivatives`` with two differences:
    it applies ``transform.compute_flux_derivative_gpu``, which also receives
    the queue, and it builds its inputs with
    ``setup.random_flux_derivative_arrays_on_device``. The output is not
    compared against a reference.
    """
    queue = device.get_queue(ctx_factory)
    prg = program.get_weno()
    prg = transform.compute_flux_derivative_gpu(queue, prg)

    params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
    # NOTE(review): the device arrays are built from ``ctx_factory`` rather
    # than from ``queue``. If the factory creates a fresh context per call,
    # they may live in a different context than ``queue`` -- confirm.
    arrays = setup.random_flux_derivative_arrays_on_device(ctx_factory, params)
    kernel.compute_flux_derivatives(queue, prg, params, arrays)


# Lets you run e.g. `python test.py "test_roe(cl._csc)"` without pytest;
# with no argument, the whole file is handed to pytest instead.
if __name__ == "__main__":
    # These names were previously used without being imported. Importing
    # them here keeps them at module scope, so the exec'd string can use
    # ``cl``.
    import sys

    import pyopencl as cl  # noqa: F401  (used by the exec'd argument)
    import pytest

    if len(sys.argv) > 1:
        # Executes the command-line argument as Python code. This is a
        # debugging convenience for whoever invokes the script; never feed
        # it untrusted input.
        exec(sys.argv[1])
    else:
        pytest.main([__file__])