import numpy as np
import numpy.linalg as la  # noqa: F401
import pyopencl as cl
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa
from pytest import approx


# {{{ arrays

def compare_arrays(a, b):
    """Assert that *a* equals *b* within tight floating-point tolerances.

    Uses :func:`pytest.approx` with ``rel=1e-12`` and ``abs=1e-14``; raises
    :exc:`AssertionError` on mismatch.
    """
    expected = approx(b, rel=1e-12, abs=1e-14)
    assert a == expected


def random_array_on_device(queue, *shape):
    """Return a device array of the given *shape* filled with random values.

    Storage is allocated by :func:`empty_array_on_device` and then filled in
    place by ``cl.clrandom.fill_rand``.
    """
    result = empty_array_on_device(queue, *shape)
    cl.clrandom.fill_rand(result)
    return result


def empty_array_on_device(queue, *shape):
    """Allocate a ``float64`` pyopencl array of the given *shape* on *queue*.

    The array uses Fortran (column-major) order. Contents are whatever
    ``cl.array.empty`` provides, i.e. not explicitly initialized here.
    """
    ary = cl.array.empty(queue, shape, dtype=np.float64, order="F")
    return ary


def arrays_from_string(string_arrays):
    """Parse several array descriptions separated by ``:``.

    Each ``:``-separated piece is handed to :func:`array_from_string`; the
    parsed arrays are returned as a list in input order.
    """
    return [array_from_string(piece) for piece in string_arrays.split(":")]


def transposed_array_from_string(string_array):
    """Parse *string_array* and return its transpose as a Fortran-ordered copy."""
    parsed = array_from_string(string_array)
    return parsed.T.copy(order="F")


def array_from_string(string_array):
    """Parse a compact textual description of a 1D, 2D or 3D array.

    The format is:

    * entries of a row are separated by whitespace;
    * rows of a 2D array are separated by ``,``;
    * 2D slabs of a 3D array are separated by ``;``;
    * a row whose first non-blank character is ``i`` is parsed as integers
      (numpy's default integer dtype), otherwise as ``float64``;
    * a leading ``,`` (2D) or ``;`` (3D) means "single column": the rest is
      parsed as one row and reshaped to ``(-1, 1)`` or ``(-1, 1, 1)``.

    Extra whitespace, e.g. ``"1 2, 3 4"`` or ``"1  2"``, is tolerated.

    Examples: ``"1 2 3"`` -> shape ``(3,)``; ``"1 2,3 4"`` -> ``(2, 2)``;
    ``",1 2 3"`` -> ``(3, 1)``; ``"1 2,3 4;5 6,7 8"`` -> ``(2, 2, 2)``.

    :returns: a Fortran-ordered (column-major) copy of the parsed array.
    """
    def array_from_string_1d(string_array):
        # Strip first so whitespace next to a ','/';' separator does not hide
        # the "i" marker. split() with no argument collapses whitespace runs
        # instead of producing empty tokens that int()/float() would reject.
        string_array = string_array.strip()
        if string_array.startswith("i"):
            return np.array([int(tok) for tok in string_array[1:].split()])
        return np.array(
                [float(tok) for tok in string_array.split()], dtype=np.float64)

    def array_from_string_2d(string_array):
        string_array = string_array.strip()
        if string_array.startswith(","):
            return array_from_string_1d(string_array[1:]).reshape((-1, 1))
        return np.array(
                [array_from_string_1d(row) for row in string_array.split(",")])

    def array_from_string_3d(string_array):
        string_array = string_array.strip()
        if string_array.startswith(";"):
            return array_from_string_1d(string_array[1:]).reshape((-1, 1, 1))
        return np.array(
                [array_from_string_2d(slab) for slab in string_array.split(";")])

    # The outermost separator present decides the rank.
    if ";" in string_array:
        array = array_from_string_3d(string_array)
    elif "," in string_array:
        array = array_from_string_2d(string_array)
    else:
        array = array_from_string_1d(string_array)
    return array.copy(order="F")


def split_map_to_list(string, map_func, splitter):
    """Split *string* on *splitter* and return ``[map_func(piece), ...]``.

    The result is always a list, in the order the pieces appear.
    """
    return [map_func(piece) for piece in string.split(splitter)]


def expand_to_6(pair):
    """Repeat each column of *pair* three times along axis 1.

    For a 2-column input this gives 6 columns. The result is returned as a
    Fortran-ordered copy.
    """
    repeated = np.repeat(pair, repeats=3, axis=1)
    return repeated.copy(order="F")

# }}}


# {{{ device

def get_queue(ctx_factory):
    """Create a context by calling *ctx_factory* and return a new queue on it."""
    return cl.CommandQueue(ctx_factory())

# }}}


# {{{ program / kernel

# Module-level cache of loopy programs built from WENO.F90. Key "default"
# holds the program as parsed by parse_weno(); any other key is a kernel name
# whose value is that program re-rooted by
# setup_weno_program_with_root_kernel().
# NOTE(review): a root kernel literally named "default" would collide with
# the parsed-program entry.
_WENO_PRG = {}


def parse_weno():
    """Parse ``WENO.F90`` with loopy and cache it as ``_WENO_PRG["default"]``.

    The file name is relative, so it is resolved against the current working
    directory.
    """
    filename = "WENO.F90"
    with open(filename, "r") as source_file:
        source = source_file.read()

    _WENO_PRG["default"] = lp.parse_transformed_fortran(
            source, filename=filename)


def get_weno_program():
    """Return the program parsed from ``WENO.F90``, parsing it on first use.

    The parsed program is cached in ``_WENO_PRG["default"]`` by
    :func:`parse_weno`.
    """
    # Check for the specific key rather than an empty cache. The cache also
    # holds re-rooted programs, so "non-empty" does not imply that "default"
    # is present.
    if "default" not in _WENO_PRG:
        parse_weno()
    return _WENO_PRG["default"]


def with_root_kernel(prg, root_name):
    """Return a copy of *prg* whose host-called entry point is *root_name*.

    The program is copied with ``name=root_name``. Every callable that the
    copy returns as an ``lp.LoopKernel`` with ``is_called_from_host`` set has
    that flag cleared. Finally, ``prg[root_name]`` (taken from the original
    *prg*) is installed with ``is_called_from_host=True``.
    """
    # FIXME This is a little less beautiful than it could be
    result = prg.copy(name=root_name)
    for callee_name in prg:
        callee = result[callee_name]
        if not isinstance(callee, lp.LoopKernel):
            continue
        if not callee.is_called_from_host:
            continue
        result = result.with_kernel(callee.copy(is_called_from_host=False))

    root = prg[root_name].copy(is_called_from_host=True)
    return result.with_kernel(root)


def setup_weno_program_with_root_kernel(knl):
    """Build the WENO program re-rooted at kernel *knl* and cache it.

    The base program comes from :func:`get_weno_program` and is re-rooted by
    :func:`with_root_kernel`. The result is stored as ``_WENO_PRG[knl]``;
    nothing is returned.
    """
    base_prg = get_weno_program()
    _WENO_PRG[knl] = with_root_kernel(base_prg, knl)


def get_weno_program_with_root_kernel(knl):
    """Return the WENO program re-rooted at *knl*, building it on first use.

    NOTE(review): because the cache shares its keys with the parsed program,
    ``knl == "default"`` returns the un-rerooted program once it is cached.
    """
    if knl in _WENO_PRG:
        return _WENO_PRG[knl]
    setup_weno_program_with_root_kernel(knl)
    return _WENO_PRG[knl]


def transform_weno_for_gpu(prg, print_kernel=False):
    """Apply GPU-oriented loopy transformations to ``compute_flux_derivatives``.

    For the inames ``i``/``j`` and their suffixed variants ``_1`` .. ``_7``:
    each ``i*`` is split by 16 with outer/inner tags ``g.0``/``l.0``, and each
    ``j*`` is split by 16 with ``g.1``/``l.1``. The assignments to
    ``delta_xi``, ``delta_eta`` and ``delta_zeta`` are turned into
    substitution rules. A barrier is then added between each consecutive pair
    of tagged stages, from ``to_generalized`` through the x/y/z flux
    compute/diff stages to ``from_generalized``.

    :arg prg: the loopy program; must contain ``compute_flux_derivatives``.
    :arg print_kernel: debugging aid. If true, prints the
        ``convert_to_generalized_frozen`` callable and then evaluates ``1/0``,
        which raises :exc:`ZeroDivisionError`.
    :returns: *prg* with the transformed kernel substituted in.
    """
    iname_suffixes = ["", "_1", "_2", "_3", "_4", "_5", "_6", "_7"]
    stage_tags = [
        "tag:to_generalized",
        "tag:flux_x_compute",
        "tag:flux_x_diff",
        "tag:flux_y_compute",
        "tag:flux_y_diff",
        "tag:flux_z_compute",
        "tag:flux_z_diff",
        "tag:from_generalized",
    ]

    knl = prg["compute_flux_derivatives"]

    for sfx in iname_suffixes:
        knl = lp.split_iname(knl, "i" + sfx, 16,
                outer_tag="g.0", inner_tag="l.0")
        knl = lp.split_iname(knl, "j" + sfx, 16,
                outer_tag="g.1", inner_tag="l.1")

    for var_name in ("delta_xi", "delta_eta", "delta_zeta"):
        knl = lp.assignment_to_subst(knl, var_name)

    # One barrier between each stage and the stage that follows it.
    for before, after in zip(stage_tags[:-1], stage_tags[1:]):
        knl = lp.add_barrier(knl, before, after)

    prg = prg.with_kernel(knl)

    # FIXME: These should work, but don't
    # FIXME: Undo the hand-inlining in WENO.F90
    #prg = lp.inline_callable_kernel(prg, "convert_to_generalized")
    #prg = lp.inline_callable_kernel(prg, "convert_from_generalized")

    if print_kernel:
        print(prg["convert_to_generalized_frozen"])
        # Apparently a deliberate hard stop after dumping the kernel.
        1/0

    return prg


def write_target_device_code(prg, outfilename="gen-code.cl"):
    """Write the device code loopy generates for *prg* to *outfilename*.

    :arg prg: the loopy program to generate code for.
    :arg outfilename: output path, overwritten if it exists; defaults to
        ``gen-code.cl``.
    """
    # Generate before opening the file. Opening with "w" truncates it, so the
    # old order left an empty file behind if code generation raised.
    device_code = lp.generate_code_v2(prg).device_code()
    with open(outfilename, "w") as outf:
        outf.write(device_code)

# }}}

# vim: foldmethod=marker
