import numpy as np
import numpy.linalg as la  # noqa: F401
import pyopencl as cl  # noqa: F401
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa

import sys
import logging

import pytest
from pyopencl.tools import (  # noqa
        pytest_generate_tests_for_pyopencl
        as pytest_generate_tests)

import utilities as u
from data_for_test import (  # noqa: F401
        flux_test_data_fixture,
        single_data as sd
        )


def test_weno_flux_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``weno_flux`` root kernel and compare with the reference flux.

    The kernel is fed the fixture's ``fluxes``, ``char_fluxes_pos``,
    ``char_fluxes_neg``, ``combined_frozen_metrics`` and ``R``. Its ``flux``
    output goes into a fresh device array sized by ``nvars``, which is then
    compared against the fixture's ``weno_flux``.
    """
    fixture = flux_test_data_fixture

    program = u.get_weno_program_with_root_kernel("weno_flux")
    queue = u.get_queue(ctx_factory)

    flux_out = u.empty_array_on_device(queue, fixture.nvars)

    program(
        queue,
        nvars=fixture.nvars,
        generalized_fluxes=fixture.fluxes,
        characteristic_fluxes_pos=fixture.char_fluxes_pos,
        characteristic_fluxes_neg=fixture.char_fluxes_neg,
        combined_frozen_metrics=fixture.combined_frozen_metrics,
        R=fixture.R,
        flux=flux_out,
    )

    u.compare_arrays(flux_out.get(), fixture.weno_flux)


def test_consistent_part_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``consistent_part`` root kernel on the fixture's ``fluxes``.

    The ``consistent`` output, held in a fresh device array sized by
    ``nvars``, is compared against the fixture's reference ``consistent``.
    """
    fixture = flux_test_data_fixture

    program = u.get_weno_program_with_root_kernel("consistent_part")
    queue = u.get_queue(ctx_factory)

    consistent_out = u.empty_array_on_device(queue, fixture.nvars)

    program(
        queue,
        nvars=fixture.nvars,
        generalized_fluxes=fixture.fluxes,
        consistent=consistent_out,
    )

    u.compare_arrays(consistent_out.get(), fixture.consistent)


def test_dissipation_part_pos_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``dissipation_part_pos`` root kernel and check its output.

    Inputs are the fixture's ``char_fluxes_pos``, ``combined_frozen_metrics``
    and ``R``. The ``dissipation_pos`` output, held in a fresh device array
    sized by ``nvars``, is compared with the fixture's ``dissipation_pos``.
    """
    fixture = flux_test_data_fixture

    program = u.get_weno_program_with_root_kernel("dissipation_part_pos")
    queue = u.get_queue(ctx_factory)

    dissipation_out = u.empty_array_on_device(queue, fixture.nvars)

    program(
        queue,
        nvars=fixture.nvars,
        characteristic_fluxes=fixture.char_fluxes_pos,
        combined_frozen_metrics=fixture.combined_frozen_metrics,
        R=fixture.R,
        dissipation_pos=dissipation_out,
    )

    u.compare_arrays(dissipation_out.get(), fixture.dissipation_pos)


def test_dissipation_part_neg_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``dissipation_part_neg`` root kernel and check its output.

    Inputs are the fixture's ``char_fluxes_neg``, ``combined_frozen_metrics``
    and ``R``. The ``dissipation_neg`` output, held in a fresh device array
    sized by ``nvars``, is compared with the fixture's ``dissipation_neg``.
    """
    fixture = flux_test_data_fixture

    program = u.get_weno_program_with_root_kernel("dissipation_part_neg")
    queue = u.get_queue(ctx_factory)

    dissipation_out = u.empty_array_on_device(queue, fixture.nvars)

    program(
        queue,
        nvars=fixture.nvars,
        characteristic_fluxes=fixture.char_fluxes_neg,
        combined_frozen_metrics=fixture.combined_frozen_metrics,
        R=fixture.R,
        dissipation_neg=dissipation_out,
    )

    u.compare_arrays(dissipation_out.get(), fixture.dissipation_neg)


def test_weno_weights_pos_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``weno_weights_pos`` root kernel and check the weights ``w``.

    Inputs are the fixture's ``char_fluxes_pos`` and
    ``combined_frozen_metrics``. The output is an ``(nvars, 3)`` device array.
    Two properties are checked:

    * summing along axis 1 yields one for each of the ``nvars`` rows;
    * the weights match the fixture's reference ``weno_weights_pos``.
    """
    data = flux_test_data_fixture

    prg = u.get_weno_program_with_root_kernel("weno_weights_pos")
    queue = u.get_queue(ctx_factory)

    weights_dev = u.empty_array_on_device(queue, data.nvars, 3)

    prg(queue, nvars=data.nvars,
            characteristic_fluxes=data.char_fluxes_pos,
            combined_frozen_metrics=data.combined_frozen_metrics,
            w=weights_dev)

    # Copy the result to the host once and reuse it for both checks instead
    # of issuing two separate device-to-host transfers of the same data.
    weights = weights_dev.get()

    sum_weights = np.sum(weights, axis=1)
    u.compare_arrays(sum_weights, np.ones(data.nvars))

    u.compare_arrays(weights, data.weno_weights_pos)


def test_weno_weights_neg_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``weno_weights_neg`` root kernel and check the weights ``w``.

    Inputs are the fixture's ``char_fluxes_neg`` and
    ``combined_frozen_metrics``. The output is an ``(nvars, 3)`` device array.
    Two properties are checked:

    * summing along axis 1 yields one for each of the ``nvars`` rows;
    * the weights match the fixture's reference ``weno_weights_neg``.
    """
    data = flux_test_data_fixture

    prg = u.get_weno_program_with_root_kernel("weno_weights_neg")
    queue = u.get_queue(ctx_factory)

    weights_dev = u.empty_array_on_device(queue, data.nvars, 3)

    prg(queue, nvars=data.nvars,
            characteristic_fluxes=data.char_fluxes_neg,
            combined_frozen_metrics=data.combined_frozen_metrics,
            w=weights_dev)

    # Copy the result to the host once and reuse it for both checks instead
    # of issuing two separate device-to-host transfers of the same data.
    weights = weights_dev.get()

    sum_weights = np.sum(weights, axis=1)
    u.compare_arrays(sum_weights, np.ones(data.nvars))

    u.compare_arrays(weights, data.weno_weights_neg)


def test_oscillation_pos_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``oscillation_pos`` root kernel on ``char_fluxes_pos``.

    The ``oscillation`` output, held in a fresh ``(nvars, 3)`` device array,
    is compared with the fixture's reference ``oscillation_pos``.
    """
    fixture = flux_test_data_fixture

    program = u.get_weno_program_with_root_kernel("oscillation_pos")
    queue = u.get_queue(ctx_factory)

    oscillation_out = u.empty_array_on_device(queue, fixture.nvars, 3)

    program(
        queue,
        nvars=fixture.nvars,
        characteristic_fluxes=fixture.char_fluxes_pos,
        oscillation=oscillation_out,
    )

    u.compare_arrays(oscillation_out.get(), fixture.oscillation_pos)


def test_oscillation_neg_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run the ``oscillation_neg`` root kernel on ``char_fluxes_neg``.

    The ``oscillation`` output, held in a fresh ``(nvars, 3)`` device array,
    is compared with the fixture's reference ``oscillation_neg``.
    """
    fixture = flux_test_data_fixture

    program = u.get_weno_program_with_root_kernel("oscillation_neg")
    queue = u.get_queue(ctx_factory)

    oscillation_out = u.empty_array_on_device(queue, fixture.nvars, 3)

    program(
        queue,
        nvars=fixture.nvars,
        characteristic_fluxes=fixture.char_fluxes_neg,
        oscillation=oscillation_out,
    )

    u.compare_arrays(oscillation_out.get(), fixture.oscillation_neg)


def test_flux_splitting_uniform_grid(ctx_factory, flux_test_data_fixture):
    """Run ``split_characteristic_fluxes`` and check both split outputs.

    Inputs are the fixture's ``states``, ``fluxes``, ``R_inv`` and
    ``wavespeeds``. The kernel writes two ``(nvars, 6)`` device arrays. They
    are compared with the fixture's ``char_fluxes_pos`` and
    ``char_fluxes_neg`` respectively.
    """
    fixture = flux_test_data_fixture

    program = u.get_weno_program_with_root_kernel("split_characteristic_fluxes")
    queue = u.get_queue(ctx_factory)

    pos_out = u.empty_array_on_device(queue, fixture.nvars, 6)
    neg_out = u.empty_array_on_device(queue, fixture.nvars, 6)

    program(
        queue,
        nvars=fixture.nvars,
        generalized_states_frozen=fixture.states,
        generalized_fluxes_frozen=fixture.fluxes,
        R_inv=fixture.R_inv,
        wavespeeds=fixture.wavespeeds,
        characteristic_fluxes_pos=pos_out,
        characteristic_fluxes_neg=neg_out,
    )

    u.compare_arrays(pos_out.get(), fixture.char_fluxes_pos)
    u.compare_arrays(neg_out.get(), fixture.char_fluxes_neg)


# Script entry point. With no arguments, run this file under pytest.
# With an argument, e.g. `python <this file> 'test_case(cl._csc)'`, the
# argument is executed as Python code (via exec) so a single test can be run
# without pytest. Only pass trusted input.
if __name__ == "__main__":
    if len(sys.argv) <= 1:
        pytest.main([__file__])
    else:
        logging.basicConfig(level="INFO")
        exec(sys.argv[1])