Skip to content
Snippets Groups Projects
test_eigensystem.py 3.75 KiB
Newer Older
  • Learn to ignore specific revisions
  • import numpy as np
    import numpy.linalg as la  # noqa: F401
    import pyopencl as cl  # noqa: F401
    import pyopencl.array  # noqa
    import pyopencl.tools  # noqa
    import pyopencl.clrandom  # noqa
    import loopy as lp  # noqa
    
    import sys
    import logging
    
    import pytest
    from pyopencl.tools import (  # noqa
            pytest_generate_tests_for_pyopencl
            as pytest_generate_tests)
    
    import utilities as u
    from data_for_test import (  # noqa: F401
            flux_test_data_fixture,
            single_data as sd
            )
    
    
    def test_pointwise_eigenvalues_ideal_gas(ctx_factory, flux_test_data_fixture):
        """Compare the ``pointwise_eigenvalues`` kernel with fixture reference data.

        The kernel is run on ``states`` from the fixture; its output buffer is
        copied back to the host and checked against ``lam_pointwise``.
        """
        case = flux_test_data_fixture

        kernel = u.get_weno_program_with_root_kernel("pointwise_eigenvalues")
        queue = u.get_queue(ctx_factory)

        # Output has nvars rows and 6 columns (presumably one per point of the
        # 6-point stencil, cf. u.expand_to_6 -- TODO confirm).
        lam_out = u.empty_array_on_device(queue, case.nvars, 6)

        kernel(queue,
               nvars=case.nvars,
               d=case.direction,
               states=case.states,
               lambda_pointwise=lam_out)

        computed = lam_out.get()
        u.compare_arrays(computed, case.lam_pointwise)
    
    
    def test_roe_uniform_grid_ideal_gas(ctx_factory, flux_test_data_fixture):
        """Validate the ``roe_eigensystem`` kernel on a pair of states.

        The kernel fills ``R``, ``R_inv`` and ``lambda_roe``.  Two properties
        are then checked on the host using the fixture's state and flux pairs:

        * ``R @ (R_inv @ dq)`` reproduces the state jump ``dq``;
        * ``R @ (lam * (R_inv @ dq))`` -- i.e. ``R diag(lam) R_inv dq`` --
          reproduces the flux jump (the Roe property).
        """
        case = flux_test_data_fixture

        def jump(pair):
            # Column 1 minus column 0 of a two-column pair.
            return pair[:, 1] - pair[:, 0]

        def check_roe_identity(states, R, R_inv):
            dq = jump(states)
            u.compare_arrays(R @ (R_inv @ dq), dq)

        def check_roe_property(states, fluxes, R, R_inv, lam):
            dq = jump(states)
            df = jump(fluxes)
            # Elementwise scaling by lam is the product with diag(lam).
            u.compare_arrays(R @ (lam * (R_inv @ dq)), df)

        kernel = u.get_weno_program_with_root_kernel("roe_eigensystem")
        queue = u.get_queue(ctx_factory)

        n = case.nvars
        R_dev = u.empty_array_on_device(queue, n, n)
        R_inv_dev = u.empty_array_on_device(queue, n, n)
        lam_dev = u.empty_array_on_device(queue, n)

        kernel(queue, nvars=case.nvars, ndim=case.ndim, d=case.direction,
               states=case.state_pair, metrics_frozen=case.frozen_metrics,
               R=R_dev, R_inv=R_inv_dev, lambda_roe=lam_dev)

        R, R_inv, lam = R_dev.get(), R_inv_dev.get(), lam_dev.get()

        check_roe_identity(case.state_pair, R, R_inv)
        check_roe_property(case.state_pair, case.flux_pair, R, R_inv, lam)
    
    
    @pytest.mark.parametrize("lam_pointwise_str,lam_roe_str,lam_expected_str", [
        ("1 2 3 4 5,2 4 6 8 10", "1.5 3 4.5 6 7.5", "2.2 4.4 6.6 8.8 11"),
        ("1 2 3 4 5,-2 -4 -6 -8 -10", "1.5 3 4.5 6 7.5", "2.2 4.4 6.6 8.8 11"),
        ("1 2 3 4 5,-2 -4 -6 -8 -10", "3 6 9 12 15", "3.3 6.6 9.9 13.2 16.5"),
        ("1 2 3 4 5,2 4 6 8 10", "-3 -6 -9 -12 -15", "3.3 6.6 9.9 13.2 16.5"),
        ("3 2 9 4 5,2 6 6 12 10", "-1 -4 -3 -8 -15", "3.3 6.6 9.9 13.2 16.5")
        ])
    def test_lax_wavespeeds(
            ctx_factory, lam_pointwise_str, lam_roe_str, lam_expected_str):
        """Check the ``lax_wavespeeds`` kernel on hand-built eigenvalue sets.

        In every parametrized case the expected value for each variable equals
        1.1 times the largest magnitude among its pointwise and Roe eigenvalues.
        """
        kernel = u.get_weno_program_with_root_kernel("lax_wavespeeds")
        queue = u.get_queue(ctx_factory)

        nvars = 5

        # Comma-separated rows are transposed, then widened to 6 columns.
        pointwise_cols = u.transposed_array_from_string(lam_pointwise_str)
        pointwise = u.expand_to_6(pointwise_cols)
        roe = u.array_from_string(lam_roe_str)
        wavespeeds_dev = u.empty_array_on_device(queue, nvars)

        kernel(queue, nvars=nvars,
               lambda_pointwise=pointwise, lambda_roe=roe,
               wavespeeds=wavespeeds_dev)

        expected = u.array_from_string(lam_expected_str)
        u.compare_arrays(wavespeeds_dev.get(), expected)
    
    
    def test_matvec(ctx_factory):
        """Compare the ``mult_mat_vec`` kernel with a NumPy matrix-vector product.

        Random ``n x n`` matrix ``a`` and length-``n`` vector ``b`` are created
        on the device; the kernel output ``c`` is checked against ``a @ b``
        computed on the host.
        """
        prg = u.get_weno_program_with_root_kernel("mult_mat_vec")
        queue = u.get_queue(ctx_factory)

        n = 10
        a = u.random_array_on_device(queue, n, n)
        b = u.random_array_on_device(queue, n)

        c = u.empty_array_on_device(queue, n)

        # NOTE(review): only alpha=1.0 is exercised; if the kernel computes
        # c = alpha * a @ b, a mishandled alpha would go unnoticed -- confirm.
        prg(queue, alpha=1.0, a=a, b=b, c=c)

        # (computed, expected) order, matching the other tests in this file.
        u.compare_arrays(c.get(), a.get() @ b.get())
    
    
    # Direct invocation: with an argument, the argument is exec'd as Python
    # (e.g. python test_eigensystem.py 'test_matvec(cl._csc)') so a single
    # test can run without pytest; with no argument, pytest runs this file.
    if __name__ == "__main__":
        if len(sys.argv) <= 1:
            pytest.main([__file__])
        else:
            logging.basicConfig(level="INFO")
            exec(sys.argv[1])