import numpy as np
import numpy.linalg as la
import pyopencl as cl
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa
from pytest import approx


### Arrays ###

def compare_arrays(a, b):
    """Assert that *a* equals *b* within pytest's default ``approx`` tolerance."""
    tolerant_expected = approx(b)
    assert a == tolerant_expected


def random_array_on_device(queue, *shape):
    """Return a new device array of the given *shape* filled with random values.

    The array is allocated with empty_array_on_device() (float32, order "F")
    and then filled in place by ``pyopencl.clrandom.fill_rand``.

    ``fill_rand`` does not return the array it fills (it returns ``None``),
    so the filled array must be returned explicitly; returning the result of
    ``fill_rand`` handed callers ``None``.
    """
    array = empty_array_on_device(queue, shape)
    cl.clrandom.fill_rand(array)
    return array


def empty_array_on_device(queue, shape):
    """Allocate an uninitialized float32 device array in Fortran (column-major) order."""
    return cl.array.empty(
        queue,
        shape,
        dtype=np.float32,
        order="F",
    )


def arrays_from_string(string_arrays):
    """Parse several array literals separated by ``:`` into a list of arrays.

    Each ``:``-separated piece is handed to array_from_string().
    """
    pieces = string_arrays.split(":")
    return [array_from_string(piece) for piece in pieces]


def array_from_string(string_array):
    """Parse a compact textual array literal into a Fortran-ordered NumPy array.

    Syntax:

    * ``"1 2 3"``: whitespace-separated numbers give a 1D float32 array;
      a leading ``i`` (``"i1 2 3"``) gives an integer array instead.
    * ``","`` separates the rows of a 2D array (``"1 2,3 4"``); a leading
      ``,`` (``",1 2 3"``) gives a single column of shape ``(n, 1)``.
    * ``";"`` separates the 2D slices of a 3D array (``"1 2,3 4;5 6,7 8"``);
      a leading ``;`` (``";1 2 3"``) gives shape ``(n, 1, 1)``.

    Numbers may be separated by any run of whitespace, and whitespace around
    each piece is ignored.

    :returns: a copy of the parsed array in Fortran (column-major) order.
    """
    def array_from_string_1d(text):
        text = text.strip()
        # str.split() with no separator collapses runs of whitespace, so
        # inputs such as "1  2" or the row " 3 4" in "1 2, 3 4" no longer
        # produce empty tokens that int()/float() reject.
        if text.startswith("i"):
            return np.array([int(tok) for tok in text[1:].split()])
        else:
            return np.array([float(tok) for tok in text.split()],
                    dtype=np.float32)

    def array_from_string_2d(text):
        text = text.strip()
        if text.startswith(","):
            # Leading comma: a single column vector.
            return array_from_string_1d(text[1:]).reshape((-1, 1))
        else:
            return np.array([array_from_string_1d(row)
                    for row in text.split(",")])

    def array_from_string_3d(text):
        text = text.strip()
        if text.startswith(";"):
            # Leading semicolon: a vector along the first axis only.
            return array_from_string_1d(text[1:]).reshape((-1, 1, 1))
        else:
            return np.array([array_from_string_2d(plane)
                    for plane in text.split(";")])

    # The separators present decide the rank.
    if ";" in string_array:
        array = array_from_string_3d(string_array)
    elif "," in string_array:
        array = array_from_string_2d(string_array)
    else:
        array = array_from_string_1d(string_array)
    return array.copy(order="F")


def split_map_to_list(string, map_func, splitter):
    """Split *string* on *splitter* and return a list of *map_func* applied to each piece."""
    return [map_func(piece) for piece in string.split(splitter)]


### Device ###

# Lazily filled cache: once setup_queue() has run, element 0 is the queue that
# every later get_queue() call hands back.
_QUEUE = []


def get_queue(ctx_factory):
    """Return the cached command queue, creating it on first use.

    *ctx_factory* is only called when nothing has been cached yet; later calls
    return the first queue no matter which factory they pass.
    """
    if len(_QUEUE) == 0:
        setup_queue(ctx_factory)
    queue = _QUEUE[0]
    return queue


def setup_queue(ctx_factory):
    """Create a context via *ctx_factory* and append a command queue on it to the cache."""
    context = ctx_factory()
    queue = cl.CommandQueue(context)
    _QUEUE.append(queue)


### Program / Kernel ###

# Lazily filled cache: once parse_weno() has run, element 0 is the parsed
# WENO program returned by every later get_weno_program() call.
_WENO_PRG = []


def get_weno_program():
    """Return the loopy program parsed from WENO.F90, parsing it on first use."""
    if len(_WENO_PRG) == 0:
        parse_weno()
    program = _WENO_PRG[0]
    return program


def parse_weno():
    """Read ``WENO.F90`` (relative to the current working directory), parse it
    with ``lp.parse_transformed_fortran`` and append the result to the cache."""
    source_path = "WENO.F90"

    with open(source_path, "r") as fortran_file:
        fortran_source = fortran_file.read()

    program = lp.parse_transformed_fortran(fortran_source, filename=source_path)
    _WENO_PRG.append(program)


def with_root_kernel(prg, root_name):
    """Return a copy of *prg* in which only *root_name* is marked host-callable.

    The copy gets ``name=root_name``. Every ``lp.LoopKernel`` in it that is
    currently marked ``is_called_from_host`` is re-marked as not host-callable,
    and then the kernel named *root_name* is marked ``is_called_from_host=True``.
    """
    # FIXME This is a little less beautiful than it could be
    new_prg = prg.copy(name=root_name)
    # NOTE(review): assumes iterating a program yields its callable names and
    # that prg[name] returns the LoopKernel for kernel callables; confirm
    # against the loopy version in use (is_called_from_host is version-specific).
    for name in prg:
        clbl = new_prg[name]
        if isinstance(clbl, lp.LoopKernel) and clbl.is_called_from_host:
            new_prg = new_prg.with_kernel(clbl.copy(is_called_from_host=False))

    # The root kernel is taken from the original *prg*; the loop above only
    # changed is_called_from_host on kernels in new_prg.
    new_prg = new_prg.with_kernel(prg[root_name].copy(is_called_from_host=True))
    return new_prg


def transform_compute_flux_derivative_gpu(queue, prg, *, dump_code_to="gen-code.cl"):
    """Apply the GPU transformations to *prg* and retarget it at *queue*'s device.

    :arg queue: command queue whose ``device`` the ``lp.PyOpenCLTarget`` is
        built for.
    :arg prg: loopy program containing ``compute_flux_derivatives``.
    :arg dump_code_to: path that receives the generated device code for
        inspection, or ``None`` to skip writing it. The default
        ``"gen-code.cl"`` matches the previous hard-coded, always-on dump.
    :returns: the transformed program with the ``no_numpy=True`` option set.
    """
    prg = transform_weno_for_gpu(prg)

    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))

    # Debugging aid: keep a copy of the generated device source on disk.
    # Previously an unconditional `if 1:` toggle; now controlled by the caller.
    if dump_code_to is not None:
        with open(dump_code_to, "w") as outf:
            outf.write(lp.generate_code_v2(prg).device_code())

    prg = lp.set_options(prg, no_numpy=True)
    return prg


def transform_weno_for_gpu(prg):
    """Apply GPU-oriented transformations to ``compute_flux_derivatives``.

    Starting from transform_compute_flux_derivative_basic(), this:

    * splits each ``i``/``j`` iname (and the suffixed copies ``_1`` .. ``_7``)
      by 16, tagging the outer parts ``g.0``/``g.1`` and the inner parts
      ``l.0``/``l.1``;
    * turns the ``delta_xi``, ``delta_eta`` and ``delta_zeta`` assignments into
      substitution rules;
    * adds a barrier between each consecutive pair of tagged stages.

    :returns: a new program containing the transformed kernel.
    """
    prg = transform_compute_flux_derivative_basic(prg)

    cfd = prg["compute_flux_derivatives"]

    for suffix in ["", "_1", "_2", "_3", "_4", "_5", "_6", "_7"]:
        cfd = lp.split_iname(cfd, "i"+suffix, 16,
                outer_tag="g.0", inner_tag="l.0")
        cfd = lp.split_iname(cfd, "j"+suffix, 16,
                outer_tag="g.1", inner_tag="l.1")

    for var_name in ["delta_xi", "delta_eta", "delta_zeta"]:
        cfd = lp.assignment_to_subst(cfd, var_name)

    # Instruction tags in stage order. A barrier (lp.add_barrier with its
    # default arguments) is inserted between every consecutive pair, in the
    # same order as the previous hand-written chain of calls.
    stage_tags = [
        "to_generalized",
        "flux_x_compute",
        "flux_x_diff",
        "flux_y_compute",
        "flux_y_diff",
        "flux_z_compute",
        "flux_z_diff",
        "from_generalized",
    ]
    for before, after in zip(stage_tags, stage_tags[1:]):
        cfd = lp.add_barrier(cfd, "tag:" + before, "tag:" + after)

    prg = prg.with_kernel(cfd)

    # FIXME: These should work, but don't
    # FIXME: Undo the hand-inlining in WENO.F90
    #prg = lp.inline_callable_kernel(prg, "convert_to_generalized")
    #prg = lp.inline_callable_kernel(prg, "convert_from_generalized")

    return prg


def transform_compute_flux_derivative_basic(prg):
    """Prepare ``compute_flux_derivatives`` with size assumptions and global temporaries.

    Adds the assumption that nx, ny and nz are positive, and places the
    temporaries listed below in the global address space.
    """
    kernel = prg["compute_flux_derivatives"]
    kernel = lp.assume(kernel, "nx > 0 and ny > 0 and nz > 0")

    global_temporaries = (
        "flux_derivatives_generalized",
        "generalized_fluxes",
        "weno_flux_tmp",
    )
    for temp_name in global_temporaries:
        kernel = lp.set_temporary_scope(kernel, temp_name,
                lp.AddressSpace.GLOBAL)

    return prg.with_kernel(kernel)


