import numpy as np
import numpy.linalg as la  # noqa
import pyopencl as cl
import pyopencl.array  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp

from pytest import approx


def with_root_kernel(prg, root_name):
    """Return a copy of *prg* named *root_name* whose host entry point is *root_name*.

    Every :class:`loopy.LoopKernel` in *prg* that is currently marked as
    called from the host is demoted (``is_called_from_host=False``).
    Afterwards the kernel ``prg[root_name]`` is re-inserted with
    ``is_called_from_host=True``. *prg* itself is not modified; a new
    program object is returned.
    """
    # FIXME This is a little less beautiful than it could be
    result = prg.copy(name=root_name)

    # First pass: demote every current host-callable loop kernel.
    for callee_name in prg:
        callee = result[callee_name]
        if not isinstance(callee, lp.LoopKernel):
            continue
        if not callee.is_called_from_host:
            continue
        result = result.with_kernel(callee.copy(is_called_from_host=False))

    # Second pass: promote the requested kernel (taken from the original
    # program) to be the host entry point.
    root_kernel = prg[root_name].copy(is_called_from_host=True)
    return result.with_kernel(root_kernel)


class LoopyFixture:
    """Test helper that parses ``WENO.F90`` with loopy and runs its kernels.

    The parsed program and the OpenCL command queue are cached in
    class-level lists. Every instance in the process therefore shares them,
    and the parse and queue creation each happen at most once.
    """

    # Single-slot caches shared by all instances (intentionally class-level).
    _WENO_PRG = []
    _QUEUE = []

    def __init__(self):
        self.prg = self.get_weno_program()

    def get_weno_program(self):
        """Return the loopy program parsed from ``WENO.F90``.

        The file is opened by a relative path, i.e. relative to the current
        working directory. The first parsed program is cached and returned
        on every later call.
        """
        if self._WENO_PRG:
            return self._WENO_PRG[0]

        fn = "WENO.F90"

        with open(fn, "r") as infile:
            infile_content = infile.read()

        prg = lp.parse_transformed_fortran(infile_content, filename=fn)
        self._WENO_PRG.append(prg)
        return prg

    def get_queue(self, ctx_factory):
        """Return the shared :class:`pyopencl.CommandQueue`.

        *ctx_factory* is only called on the first invocation. Later calls
        return the cached queue no matter which factory is passed.
        """
        if not self._QUEUE:
            ctx = ctx_factory()
            self._QUEUE.append(cl.CommandQueue(ctx))
        return self._QUEUE[0]

    def random_array(self, *args):
        """Return a ``float32`` array of shape *args* in Fortran order.

        Entries are drawn uniformly from ``[0, 1)``.
        """
        # One astype call both converts the dtype and lays the copy out in
        # column-major order, avoiding a second redundant copy.
        return np.random.random_sample(args).astype(np.float32, order="F")

    def mult_mat_vec(self, ctx_factory, alpha, a, b):
        """Run the ``mult_mat_vec`` kernel and return its output on the host.

        Presumably this computes ``alpha * a @ b``; confirm against WENO.F90.

        :arg alpha: scalar passed through to the kernel.
        :arg a: matrix operand. Its leading dimension sizes the output.
        :arg b: vector operand.
        :returns: a ``float32`` :class:`numpy.ndarray` with ``a.shape[0]``
            entries.
        """
        queue = self.get_queue(ctx_factory)

        # Size the output from the matrix instead of a hard-coded length, so
        # that operands with any number of rows get a correctly sized buffer.
        c_dev = cl.array.empty(queue, a.shape[0], dtype=np.float32)

        prg = with_root_kernel(self.prg, "mult_mat_vec")
        prg(queue, a=a, b=b, c=c_dev, alpha=alpha)

        return c_dev.get()

    def compute_flux_derivatives(self, ctx_factory,
            nvars, ndim, nx, ny, nz,
            states, fluxes, metrics, metric_jacobians):
        """Run ``compute_flux_derivatives`` and return ``flux_derivatives``.

        Before running, the kernel is transformed in the following ways:

        * it is given the assumption ``nx, ny, nz > 0``;
        * three of its temporaries are placed in global memory.

        The transformed program is not cached, so these transformations are
        redone on every call.

        *nx*, *ny* and *nz* are only used here to size the output array.
        They are not passed to the kernel explicitly; presumably loopy infers
        them from the argument shapes (confirm).

        :returns: a ``float32`` host array of shape
            ``(nvars, ndim, nx+6, ny+6, nz+6)`` in Fortran order.
        """
        queue = self.get_queue(ctx_factory)

        prg = self.prg
        cfd = prg["compute_flux_derivatives"]

        cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0")

        for tmp_name in ("flux_derivatives_generalized",
                "generalized_fluxes",
                "weno_flux_tmp"):
            cfd = lp.set_temporary_scope(cfd, tmp_name,
                    lp.AddressSpace.GLOBAL)

        prg = prg.with_kernel(cfd)

        # The +6 padding on each spatial axis presumably holds 3 ghost cells
        # per side for the WENO stencil. TODO confirm against WENO.F90.
        flux_derivatives_dev = cl.array.empty(queue, (nvars, ndim, nx+6, ny+6,
            nz+6), dtype=np.float32, order="F")

        prg(queue, nvars=nvars, ndim=ndim,
                states=states, fluxes=fluxes, metrics=metrics,
                metric_jacobians=metric_jacobians,
                flux_derivatives=flux_derivatives_dev)
        return flux_derivatives_dev.get()
