import numpy as np
import numpy.linalg as la  # noqa: F401
import pyopencl as cl  # noqa: F401
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa

import sys
import logging

import pytest

import utilities as u
from data_for_test import (  # noqa: F401
        cfd_test_data_fixture
        )


def test_compute_flux_derivatives_uniform_grid(queue, cfd_test_data_fixture):
    """Check the WENO kernel's flux derivatives against stored reference data.

    The WENO program is targeted at ``queue``'s device and run on the inputs
    provided by ``cfd_test_data_fixture``. The device output is copied back
    to the host and compared, via ``u.compare_arrays``, with the fixture's
    precomputed ``flux_derivatives``.
    """
    fixture = cfd_test_data_fixture

    weno = u.get_weno_program().copy(target=lp.PyOpenCLTarget(queue.device))

    # Output buffer for the kernel, shaped by the fixture's flux_dims.
    result_dev = u.empty_array_on_device(queue, *fixture.flux_dims)

    kernel_args = dict(
        nvars=fixture.nvars,
        ndim=fixture.ndim,
        nx=fixture.nx,
        ny=fixture.ny,
        nz=fixture.nz,
        d=fixture.direction,
        states=fixture.states,
        fluxes=fixture.fluxes,
        metrics=fixture.metrics,
        metric_jacobians=fixture.jacobians,
        flux_derivatives=result_dev,
    )
    weno(queue, **kernel_args)

    u.compare_arrays(result_dev.get(), fixture.flux_derivatives)


@pytest.mark.slow
def test_compute_flux_derivatives(ctx_factory):
    """Check the untransformed WENO program with ``lp.auto_test_vs_ref``.

    The program is targeted at the device of the queue returned by
    ``u.get_queue`` and tested with no warm-up rounds. The problem is a
    16x16x16 grid with ``nvars=5``, ``ndim=3`` and direction ``d=0``.
    """
    prg = u.get_weno_program()

    queue = u.get_queue(ctx_factory)
    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))

    # Reuse the queue's context instead of calling ctx_factory() again:
    # the test then runs in the very context whose device the target was
    # built for, and no second OpenCL context is created.
    lp.auto_test_vs_ref(prg, queue.context, warmup_rounds=0,
            parameters=dict(ndim=3, nvars=5, nx=16, ny=16, nz=16, d=0))


@pytest.mark.slow
def test_compute_flux_derivatives_gpu(ctx_factory, write_code=False):
    """Check the GPU-transformed WENO program with ``lp.auto_test_vs_ref``.

    The program is transformed by ``u.transform_weno_for_gpu`` and targeted
    at the device of the queue returned by ``u.get_queue``. It is built with
    the ``no_numpy`` option and tested with no warm-up rounds. The problem
    is a 16x16x16 grid with ``nvars=5``, ``ndim=3`` and direction ``d=0``.

    :arg write_code: if true, the final program is passed to
        ``u.write_target_device_code`` before the test is run.
    """
    prg = u.get_weno_program()
    prg = u.transform_weno_for_gpu(prg)

    queue = u.get_queue(ctx_factory)
    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
    prg = lp.set_options(prg, no_numpy=True)

    if write_code:
        u.write_target_device_code(prg)

    # Reuse the queue's context instead of calling ctx_factory() again:
    # the test then runs in the very context whose device the target was
    # built for, and no second OpenCL context is created.
    lp.auto_test_vs_ref(prg, queue.context, warmup_rounds=0,
            parameters=dict(ndim=3, nvars=5, nx=16, ny=16, nz=16, d=0))


# This lets you run a single test case without pytest, e.g.
#     python test.py 'test_case(cl._csc)'
# The first command-line argument is executed as Python code.
# NOTE: exec() runs arbitrary code taken from argv; this is a developer
# convenience only and must never be fed untrusted input.
if __name__ == "__main__":
    if len(sys.argv) > 1:
        logging.basicConfig(level="INFO")
        exec(sys.argv[1])
    else:
        # Propagate pytest's exit status so test failures make the script
        # exit non-zero (visible to the shell / CI).
        sys.exit(pytest.main([__file__]))
