import sys
import logging

import pytest
from pyopencl.tools import (  # noqa
        pytest_generate_tests_for_pyopencl
        as pytest_generate_tests)

import device_fixtures as device
import program_fixtures as program
import transform_fixtures as transform
import setup_fixtures as setup
import kernel_fixtures as kernel
import comparison_fixtures as compare

def test_matvec(ctx_factory):
    """Check the WENO ``mult_mat_vec`` kernel against NumPy's ``@`` product.

    A random 10x10 matrix and a random length-10 vector are fed to the
    kernel with ``alpha=1.0``; the kernel output is then handed to
    ``compare.arrays`` together with the host-side product ``a @ b``.
    """
    size = 10
    queue = device.get_queue(ctx_factory)
    weno_prg = program.get_weno()

    mat = setup.random_array(size, size)
    vec = setup.random_array(size)

    # Run the kernel first, then build the host reference from the same inputs.
    actual = kernel.mult_mat_vec(queue, weno_prg, alpha=1.0, a=mat, b=vec)
    expected = mat @ vec

    compare.arrays(expected, actual)


def test_compute_flux_derivatives(ctx_factory):
    """Smoke-test the basic-transformed WENO flux-derivative kernel.

    The WENO program is transformed with
    ``transform.compute_flux_derivative_basic``. It is then run on random
    inputs for a 3-D, 5-variable problem with ``n=10``. No result is
    compared, so the test passes as long as nothing raises.
    """
    queue = device.get_queue(ctx_factory)
    transformed_prg = transform.compute_flux_derivative_basic(
            program.get_weno())

    flux_params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
    flux_arrays = setup.random_flux_derivative_arrays(flux_params)

    kernel.compute_flux_derivatives(
            queue, transformed_prg, flux_params, flux_arrays)


def test_compute_flux_derivatives_gpu(ctx_factory):
    """Smoke-test the GPU-transformed WENO flux-derivative kernel.

    The WENO program is transformed with
    ``transform.compute_flux_derivative_gpu``, which also receives the queue.
    It is then run on random device-side inputs for a 3-D, 5-variable
    problem with ``n=10``. No result is compared, so the test passes as long
    as nothing raises.
    """
    queue = device.get_queue(ctx_factory)
    gpu_prg = transform.compute_flux_derivative_gpu(queue, program.get_weno())

    flux_params = setup.flux_derivative_params(ndim=3, nvars=5, n=10)
    # NOTE(review): the device arrays are built from ``ctx_factory`` rather
    # than from ``queue``. If the helper calls ``ctx_factory()``, it may get a
    # context different from ``queue.context``. Confirm this is intended.
    device_arrays = setup.random_flux_derivative_arrays_on_device(
            ctx_factory, flux_params)

    kernel.compute_flux_derivatives(queue, gpu_prg, flux_params, device_arrays)


# This lets you run a single test without pytest, e.g.
#     python test.py 'test_matvec(cl._csc)'
# (quote the argument so the shell does not interpret the parentheses).
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # The snippet passed on the command line refers to ``cl`` (e.g.
        # ``cl._csc``), but this module never imports pyopencl under that
        # name. Bind it at module scope so the exec'd code can see it.
        import pyopencl as cl  # noqa: F401
        logging.basicConfig(level="INFO")
        # Developer convenience only: this executes arbitrary Python taken
        # from the command line, so never wire it to untrusted input.
        exec(sys.argv[1])
    else:
        pytest.main([__file__])
