import loopy as lp  # noqa

import setup_fixtures as setup


def with_root_kernel(prg, root_name):
    """Return a copy of *prg* whose root is the callable named *root_name*.

    The copy is renamed to *root_name*. Every ``lp.LoopKernel`` in *prg*
    whose ``is_called_from_host`` flag is set has that flag cleared. Then
    the kernel ``prg[root_name]`` is re-inserted with the flag set.
    *prg* itself is not modified; the re-rooted copy is returned.
    """
    # FIXME: this is a little less beautiful than it could be.
    rerooted = prg.copy(name=root_name)

    for callable_name in prg:
        callee = rerooted[callable_name]
        # Only loopy kernels carry the host-call flag we need to reset.
        if not isinstance(callee, lp.LoopKernel):
            continue
        if not callee.is_called_from_host:
            continue
        rerooted = rerooted.with_kernel(
                callee.copy(is_called_from_host=False))

    # Mark the requested kernel as the (only) host-callable entry point.
    root_kernel = prg[root_name]
    return rerooted.with_kernel(root_kernel.copy(is_called_from_host=True))


def roe_eigensystem(queue, prg, params, states, metrics_frozen):
    """Run the ``roe_eigensystem`` kernel of *prg* and fetch its outputs.

    Device buffers for ``R`` and ``R_inv`` are allocated with shape
    ``params.mat_bounds()``. The buffer for ``lambda_roe`` is allocated with
    shape ``params.vec_bounds()``. *prg* is re-rooted at
    ``"roe_eigensystem"`` and invoked with those buffers.

    Returns:
        A 3-tuple ``(R, R_inv, lambda_roe)`` of host copies (via ``.get()``)
        of the buffers after the kernel has run.
    """
    r_dev = setup.empty_array_on_device(queue, params.mat_bounds())
    r_inv_dev = setup.empty_array_on_device(queue, params.mat_bounds())
    lambda_dev = setup.empty_array_on_device(queue, params.vec_bounds())

    rooted = with_root_kernel(prg, "roe_eigensystem")
    rooted(
        queue,
        nvars=params.nvars,
        ndim=params.ndim,
        d=params.d,
        states=states,
        metrics_frozen=metrics_frozen,
        R=r_dev,
        R_inv=r_inv_dev,
        lambda_roe=lambda_dev,
    )

    return tuple(buf.get() for buf in (r_dev, r_inv_dev, lambda_dev))


def mult_mat_vec(queue, prg, alpha, a, b):
    """Run the ``mult_mat_vec`` kernel of *prg* and return its ``c`` output.

    A device buffer with the same shape as *b* is allocated for ``c``.
    *prg* is re-rooted at ``"mult_mat_vec"`` and invoked with
    ``a``, ``b``, ``c`` and ``alpha``. The host copy of ``c`` is returned.
    """
    # NOTE(review): sizing c from b.shape presumably assumes a is square
    # (or that the kernel maps b-shaped input to b-shaped output) — confirm.
    out_dev = setup.empty_array_on_device(queue, b.shape)

    rooted = with_root_kernel(prg, "mult_mat_vec")
    rooted(queue, a=a, b=b, c=out_dev, alpha=alpha)

    result = out_dev.get()
    return result


def compute_flux_derivatives(queue, prg, params, arrays):
    """Invoke *prg* to compute flux derivatives and return them on host.

    An output buffer of shape
    ``(nvars, ndim, nx_halo, ny_halo, nz_halo)`` (taken from *params*) is
    allocated. *prg* is called with the states, fluxes, metrics and metric
    Jacobians from *arrays*, and a host copy of the buffer is returned.
    """
    # NOTE(review): unlike the sibling helpers, prg is called as-is without
    # with_root_kernel; presumably its root is already the flux-derivative
    # kernel — confirm against callers.
    out_shape = (params.nvars, params.ndim,
                 params.nx_halo, params.ny_halo, params.nz_halo)
    out_dev = setup.empty_array_on_device(queue, out_shape)

    prg(
        queue,
        nvars=params.nvars,
        ndim=params.ndim,
        states=arrays.states,
        fluxes=arrays.fluxes,
        metrics=arrays.metrics,
        metric_jacobians=arrays.metric_jacobians,
        flux_derivatives=out_dev,
    )

    result = out_dev.get()
    return result
