import numpy as np
import numpy.linalg as la  # noqa: F401
import pyopencl as cl  # noqa: F401
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa

import sys
import logging

import pytest
from pyopencl.tools import (  # noqa
        pytest_generate_tests_for_pyopencl
        as pytest_generate_tests)

import utilities as u
from data_for_test import flux_test_data_fixture  # noqa: F401


def test_weno_weight_computation(ctx_factory, flux_test_data_fixture):
    data = flux_test_data_fixture


@pytest.mark.slow
def test_weno_flux_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Compare the ``weno_flux`` kernel output with the fixture's ``weno_flux``.

    The kernel is fed the fixture's generalized and split characteristic
    fluxes, its ``R`` matrix, and ``combined_frozen_metrics=1.0``.
    """
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("weno_flux")
    cl_queue = u.get_queue(ctx_factory)

    result_dev = u.empty_array_on_device(cl_queue, ref.nvars)

    kernel(
        cl_queue,
        nvars=ref.nvars,
        generalized_fluxes=ref.fluxes,
        characteristic_fluxes_pos=ref.char_fluxes_pos,
        characteristic_fluxes_neg=ref.char_fluxes_neg,
        combined_frozen_metrics=1.0,
        R=ref.R,
        flux=result_dev,
    )

    u.compare_arrays(result_dev.get(), ref.weno_flux)


@pytest.mark.slow
def test_consistent_part_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Compare the ``consistent_part`` kernel with the fixture's ``consistent``."""
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("consistent_part")
    cl_queue = u.get_queue(ctx_factory)

    result_dev = u.empty_array_on_device(cl_queue, ref.nvars)

    kernel(
        cl_queue,
        nvars=ref.nvars,
        generalized_fluxes=ref.fluxes,
        consistent=result_dev,
    )

    u.compare_arrays(result_dev.get(), ref.consistent)


@pytest.mark.slow
def test_dissipation_part_pos_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Compare ``dissipation_part_pos`` with the fixture's ``dissipation_pos``.

    The kernel is driven by the positive characteristic fluxes.
    """
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("dissipation_part_pos")
    cl_queue = u.get_queue(ctx_factory)

    result_dev = u.empty_array_on_device(cl_queue, ref.nvars)

    kernel_args = dict(
        nvars=ref.nvars,
        characteristic_fluxes=ref.char_fluxes_pos,
        combined_frozen_metrics=1.0,
        R=ref.R,
        dissipation_pos=result_dev,
    )
    kernel(cl_queue, **kernel_args)

    u.compare_arrays(result_dev.get(), ref.dissipation_pos)


@pytest.mark.slow
def test_dissipation_part_neg_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Compare ``dissipation_part_neg`` with the fixture's ``dissipation_neg``.

    The kernel is driven by the negative characteristic fluxes.
    """
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("dissipation_part_neg")
    cl_queue = u.get_queue(ctx_factory)

    result_dev = u.empty_array_on_device(cl_queue, ref.nvars)

    kernel_args = dict(
        nvars=ref.nvars,
        characteristic_fluxes=ref.char_fluxes_neg,
        combined_frozen_metrics=1.0,
        R=ref.R,
        dissipation_neg=result_dev,
    )
    kernel(cl_queue, **kernel_args)

    u.compare_arrays(result_dev.get(), ref.dissipation_neg)


@pytest.mark.slow
def test_weno_weights_pos_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Check the ``weno_weights_pos`` kernel on the positive characteristic fluxes.

    Two checks are made on the ``(nvars, 3)`` output buffer:

    * summed over axis 1, the weights must equal ``np.ones(nvars)``;
    * the weights must match the fixture's ``weno_weights_pos``.
    """
    data = flux_test_data_fixture

    prg = u.get_weno_program_with_root_kernel("weno_weights_pos")
    queue = u.get_queue(ctx_factory)

    weights_dev = u.empty_array_on_device(queue, data.nvars, 3)

    prg(queue, nvars=data.nvars,
            characteristic_fluxes=data.char_fluxes_pos,
            combined_frozen_metrics=1.0,
            w=weights_dev)

    # Copy the result to the host once and reuse it for both checks,
    # instead of doing a separate device->host transfer per check.
    weights = weights_dev.get()

    sum_weights = np.sum(weights, axis=1)
    u.compare_arrays(sum_weights, np.ones(data.nvars))

    u.compare_arrays(weights, data.weno_weights_pos)


@pytest.mark.slow
def test_weno_weights_neg_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Check the ``weno_weights_neg`` kernel on the negative characteristic fluxes.

    Two checks are made on the ``(nvars, 3)`` output buffer:

    * summed over axis 1, the weights must equal ``np.ones(nvars)``;
    * the weights must match the fixture's ``weno_weights_neg``.
    """
    data = flux_test_data_fixture

    prg = u.get_weno_program_with_root_kernel("weno_weights_neg")
    queue = u.get_queue(ctx_factory)

    weights_dev = u.empty_array_on_device(queue, data.nvars, 3)

    prg(queue, nvars=data.nvars,
            characteristic_fluxes=data.char_fluxes_neg,
            combined_frozen_metrics=1.0,
            w=weights_dev)

    # Copy the result to the host once and reuse it for both checks,
    # instead of doing a separate device->host transfer per check.
    weights = weights_dev.get()

    sum_weights = np.sum(weights, axis=1)
    u.compare_arrays(sum_weights, np.ones(data.nvars))

    u.compare_arrays(weights, data.weno_weights_neg)


@pytest.mark.slow
def test_oscillation_pos_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Compare ``oscillation_pos`` with the fixture's ``oscillation_pos``.

    The output buffer is allocated as ``(nvars, 3)``.
    """
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("oscillation_pos")
    cl_queue = u.get_queue(ctx_factory)

    result_dev = u.empty_array_on_device(cl_queue, ref.nvars, 3)

    kernel(
        cl_queue,
        nvars=ref.nvars,
        characteristic_fluxes=ref.char_fluxes_pos,
        oscillation=result_dev,
    )

    u.compare_arrays(result_dev.get(), ref.oscillation_pos)


@pytest.mark.slow
def test_oscillation_neg_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Compare ``oscillation_neg`` with the fixture's ``oscillation_neg``.

    The output buffer is allocated as ``(nvars, 3)``.
    """
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("oscillation_neg")
    cl_queue = u.get_queue(ctx_factory)

    result_dev = u.empty_array_on_device(cl_queue, ref.nvars, 3)

    kernel(
        cl_queue,
        nvars=ref.nvars,
        characteristic_fluxes=ref.char_fluxes_neg,
        oscillation=result_dev,
    )

    u.compare_arrays(result_dev.get(), ref.oscillation_neg)


@pytest.mark.slow
def test_flux_splitting_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Check ``split_characteristic_fluxes`` against the fixture's split fluxes.

    Both ``(nvars, 6)`` outputs (positive and negative parts) are compared
    with the fixture's ``char_fluxes_pos`` and ``char_fluxes_neg``.
    """
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("split_characteristic_fluxes")
    cl_queue = u.get_queue(ctx_factory)

    pos_dev = u.empty_array_on_device(cl_queue, ref.nvars, 6)
    neg_dev = u.empty_array_on_device(cl_queue, ref.nvars, 6)

    kernel(
        cl_queue,
        nvars=ref.nvars,
        generalized_states_frozen=ref.states,
        generalized_fluxes_frozen=ref.fluxes,
        R_inv=ref.R_inv,
        wavespeeds=ref.wavespeeds,
        characteristic_fluxes_pos=pos_dev,
        characteristic_fluxes_neg=neg_dev,
    )

    for computed_dev, expected in ((pos_dev, ref.char_fluxes_pos),
                                   (neg_dev, ref.char_fluxes_neg)):
        u.compare_arrays(computed_dev.get(), expected)


@pytest.mark.slow
def test_pointwise_eigenvalues_ideal_gas(ctx_factory, flux_test_data_fixture):
    """Compare ``pointwise_eigenvalues`` with the fixture's ``lam_pointwise``.

    The kernel receives the fixture's states and direction ``d``; its output
    buffer is allocated as ``(nvars, 6)``.
    """
    ref = flux_test_data_fixture

    kernel = u.get_weno_program_with_root_kernel("pointwise_eigenvalues")
    cl_queue = u.get_queue(ctx_factory)

    eigenvalues_dev = u.empty_array_on_device(cl_queue, ref.nvars, 6)

    kernel(
        cl_queue,
        nvars=ref.nvars,
        d=ref.direction,
        states=ref.states,
        lambda_pointwise=eigenvalues_dev,
    )

    u.compare_arrays(eigenvalues_dev.get(), ref.lam_pointwise)


@pytest.mark.slow
def test_roe_uniform_grid_ideal_gas(ctx_factory, flux_test_data_fixture):
    """Check the ``roe_eigensystem`` kernel on the fixture's state pair.

    With ``d_state``/``d_flux`` the difference between column 1 and column 0
    of ``state_pair``/``flux_pair``, the returned ``R``, ``R_inv`` and
    ``lambda_roe`` must satisfy:

    * ``R @ (R_inv @ d_state)`` reproduces ``d_state``;
    * ``R @ (lam * (R_inv @ d_state))`` reproduces ``d_flux``.
    """
    ref = flux_test_data_fixture

    def jump(pair):
        # Second column minus first column of a two-column array.
        return pair[:, 1] - pair[:, 0]

    kernel = u.get_weno_program_with_root_kernel("roe_eigensystem")
    cl_queue = u.get_queue(ctx_factory)

    # Identity metrics: float64, Fortran (column-major) ordered.
    metrics_frozen = np.eye(ref.ndim, dtype=np.float64, order="F")

    nvars = ref.nvars
    R_dev = u.empty_array_on_device(cl_queue, nvars, nvars)
    R_inv_dev = u.empty_array_on_device(cl_queue, nvars, nvars)
    lam_dev = u.empty_array_on_device(cl_queue, nvars)

    kernel(
        cl_queue,
        nvars=nvars,
        ndim=ref.ndim,
        d=ref.direction,
        states=ref.state_pair,
        metrics_frozen=metrics_frozen,
        R=R_dev,
        R_inv=R_inv_dev,
        lambda_roe=lam_dev,
    )

    R = R_dev.get()
    R_inv = R_inv_dev.get()
    lam = lam_dev.get()

    d_state = jump(ref.state_pair)
    d_flux = jump(ref.flux_pair)

    # R and R_inv must be mutual inverses on the state jump.
    u.compare_arrays(R @ (R_inv @ d_state), d_state)
    # The eigen-decomposition must map the state jump onto the flux jump.
    u.compare_arrays(R @ (lam * (R_inv @ d_state)), d_flux)


@pytest.mark.slow
@pytest.mark.parametrize("lam_pointwise_str,lam_roe_str,lam_expected_str", [
    ("1 2 3 4 5,2 4 6 8 10", "1.5 3 4.5 6 7.5", "2.2 4.4 6.6 8.8 11"),
    ("1 2 3 4 5,-2 -4 -6 -8 -10", "1.5 3 4.5 6 7.5", "2.2 4.4 6.6 8.8 11"),
    ("1 2 3 4 5,-2 -4 -6 -8 -10", "3 6 9 12 15", "3.3 6.6 9.9 13.2 16.5"),
    ("1 2 3 4 5,2 4 6 8 10", "-3 -6 -9 -12 -15", "3.3 6.6 9.9 13.2 16.5"),
    ("3 2 9 4 5,2 6 6 12 10", "-1 -4 -3 -8 -15", "3.3 6.6 9.9 13.2 16.5")
    ])
def test_lax_wavespeeds(
        ctx_factory, lam_pointwise_str, lam_roe_str, lam_expected_str):
    """Check the ``lax_wavespeeds`` kernel on small hand-written cases.

    Each case supplies pointwise eigenvalues (comma-separated groups, parsed
    with ``u.transposed_array_from_string`` and widened by ``u.expand_to_6``),
    Roe eigenvalues, and the expected wavespeeds, all for five variables.
    """
    kernel = u.get_weno_program_with_root_kernel("lax_wavespeeds")
    cl_queue = u.get_queue(ctx_factory)

    n_vars = 5

    pointwise = u.expand_to_6(
        u.transposed_array_from_string(lam_pointwise_str))
    roe = u.array_from_string(lam_roe_str)
    wavespeeds_dev = u.empty_array_on_device(cl_queue, n_vars)

    kernel(
        cl_queue,
        nvars=n_vars,
        lambda_pointwise=pointwise,
        lambda_roe=roe,
        wavespeeds=wavespeeds_dev,
    )

    expected = u.array_from_string(lam_expected_str)
    u.compare_arrays(wavespeeds_dev.get(), expected)


@pytest.mark.slow
def test_matvec(ctx_factory):
    """Check ``mult_mat_vec`` (called with ``alpha=1.0``) against NumPy.

    Random 10x10 ``a`` and length-10 ``b`` are generated on the device; the
    kernel's ``c`` is compared with the host product ``a @ b``.
    """
    prg = u.get_weno_program_with_root_kernel("mult_mat_vec")
    queue = u.get_queue(ctx_factory)

    a = u.random_array_on_device(queue, 10, 10)
    b = u.random_array_on_device(queue, 10)

    c = u.empty_array_on_device(queue, 10)

    prg(queue, alpha=1.0, a=a, b=b, c=c)

    # Computed result first, host reference second: the same
    # (computed, expected) order every other test in this file uses.
    # NOTE(review): presumably compare_arrays measures tolerance relative
    # to its second argument -- confirm in utilities.
    u.compare_arrays(c.get(), a.get()@b.get())


@pytest.mark.slow
def test_compute_flux_derivatives(ctx_factory):
    """Run the full WENO program through ``lp.auto_test_vs_ref``.

    The program is retargeted to the test queue's device and compared with
    loopy's reference execution for a 16^3 grid with ``ndim=3, nvars=5``.
    """
    program = u.get_weno_program()

    cl_queue = u.get_queue(ctx_factory)
    target = lp.PyOpenCLTarget(cl_queue.device)
    program = program.copy(target=target)

    problem_size = dict(ndim=3, nvars=5, nx=16, ny=16, nz=16)
    lp.auto_test_vs_ref(program, ctx_factory(), warmup_rounds=1,
            parameters=problem_size)


@pytest.mark.slow
def test_compute_flux_derivatives_gpu(ctx_factory, write_code=False):
    """Run the GPU-transformed WENO program through ``lp.auto_test_vs_ref``.

    The program is transformed with ``u.transform_weno_for_gpu``, retargeted
    to the test queue's device, and given ``no_numpy=True``. When
    ``write_code`` is true, the generated device code is written out with
    ``u.write_target_device_code`` before testing.
    """
    program = u.transform_weno_for_gpu(u.get_weno_program())

    cl_queue = u.get_queue(ctx_factory)
    target = lp.PyOpenCLTarget(cl_queue.device)
    program = lp.set_options(program.copy(target=target), no_numpy=True)

    if write_code:
        u.write_target_device_code(program)

    problem_size = dict(ndim=3, nvars=5, nx=16, ny=16, nz=16)
    lp.auto_test_vs_ref(program, ctx_factory(), warmup_rounds=1,
            parameters=problem_size)


# This lets you run python <this file> 'test_case(cl._csc)' without pytest
# (quote the argument so the shell does not interpret the parentheses).
if __name__ == "__main__":
    if len(sys.argv) > 1:
        logging.basicConfig(level="INFO")
        # Developer convenience: evaluates arbitrary Python taken from the
        # command line. Only ever pass trusted input here.
        exec(sys.argv[1])
    else:
        # Propagate pytest's exit status so shell/CI callers see failures;
        # previously the script always exited with status 0.
        sys.exit(pytest.main([__file__]))
