import numpy as np
import numpy.linalg as la
import pyopencl as cl
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa
import sys

import logging

import pytest
from pyopencl.tools import (  # noqa
        pytest_generate_tests_for_pyopencl
        as pytest_generate_tests)

import fixtures
import comparison_fixtures as compare
import setup_fixtures as setup
import kernel_fixtures as kernel


def test_matvec(ctx_factory):
    """Compare ``kernel.mult_mat_vec`` against NumPy's ``mat @ vec``.

    A 10x10 matrix and a length-10 vector are built with
    ``setup.random_array``. They are multiplied through the kernel fixture
    with ``alpha=1.0``, and the result is checked against the host-side
    product via ``compare.arrays``.
    """
    n = 10
    mat = setup.random_array(n, n)
    vec = setup.random_array(n)

    result = kernel.mult_mat_vec(ctx_factory, a=mat, b=vec, alpha=1.0)

    expected = mat @ vec
    compare.arrays(expected, result)


def test_compute_flux_derivatives(ctx_factory):
    """Run ``fixtures.compute_flux_derivatives`` on random padded inputs.

    Problem size: 3 dimensions, 5 variables and a 10x10x10 interior. Every
    spatial extent of the input arrays is padded by 6 cells. The test body
    itself asserts nothing; any verification presumably happens inside the
    fixture (TODO confirm).
    """
    # The sizes are defined exactly once. The previously constructed (and
    # never used) ``setup.FluxDerivativeParams`` duplicated them and could
    # drift out of sync, so it was removed.
    ndim = 3
    nvars = 5
    nx = ny = nz = 10

    # Each spatial extent carries 6 extra cells (presumably 3 ghost cells per
    # side for the WENO stencil -- TODO confirm against the fixture).
    padded = (nx + 6, ny + 6, nz + 6)

    states = setup.random_array(nvars, *padded)
    fluxes = setup.random_array(nvars, ndim, *padded)
    metrics = setup.random_array(ndim, ndim, *padded)
    metric_jacobians = setup.random_array(*padded)

    fixtures.compute_flux_derivatives(ctx_factory,
            nvars=nvars, ndim=ndim, nx=nx, ny=ny, nz=nz,
            states=states, fluxes=fluxes, metrics=metrics,
            metric_jacobians=metric_jacobians)


@pytest.mark.skip("slow")
def test_compute_flux_derivatives_gpu(ctx_factory):
    """Execute the GPU-transformed WENO kernel on random Fortran-order inputs.

    The test is skipped by default as slow. It does the following:
    - Retargets the kernel to the queue's device.
    - Writes the generated device code to ``gen-code.cl`` in the current
      working directory.
    - Disables numpy handling (``no_numpy=True``).
    - Invokes the kernel twice with identical arguments.

    No result is checked. The test only exercises code generation and
    execution.
    """
    knl = fixtures.get_gpu_transformed_weno()
    queue = fixtures.get_queue(ctx_factory)

    ndim, nvars = 3, 5
    nx = ny = nz = 10
    # Every spatial extent is padded by 6 cells (presumably the WENO ghost
    # layers -- TODO confirm).
    padded = (nx + 6, ny + 6, nz + 6)

    states = fixtures.f_array(queue, nvars, *padded)
    fluxes = fixtures.f_array(queue, nvars, ndim, *padded)
    metrics = fixtures.f_array(queue, ndim, ndim, *padded)
    metric_jacobians = fixtures.f_array(queue, *padded)

    flux_derivatives_dev = cl.array.empty(
        queue, (nvars, ndim, *padded), dtype=np.float32, order="F")

    knl = knl.copy(target=lp.PyOpenCLTarget(queue.device))

    # Debug aid: always dump the generated OpenCL source.
    with open("gen-code.cl", "w") as code_file:
        code_file.write(lp.generate_code_v2(knl).device_code())

    knl = lp.set_options(knl, no_numpy=True)

    kernel_args = dict(
        nvars=nvars, ndim=ndim,
        states=states, fluxes=fluxes, metrics=metrics,
        metric_jacobians=metric_jacobians,
        flux_derivatives=flux_derivatives_dev)

    # The kernel is run twice with the same arguments. The reason is not
    # visible here (possibly to exercise a warm/cached invocation --
    # TODO confirm).
    for _ in range(2):
        knl(queue, **kernel_args)


# Running this file directly with an argument executes that argument as
# Python code without going through pytest, e.g.
#     python <this_file>.py 'test_matvec(cl._csc)'
# With no argument, pytest is run on this file.
if __name__ == "__main__":
    if len(sys.argv) <= 1:
        from pytest import main
        main([__file__])
    else:
        logging.basicConfig(level="INFO")
        # NOTE: exec() runs arbitrary code taken from the command line. This
        # is a developer convenience, not something to expose to untrusted
        # input.
        exec(sys.argv[1])
