import numpy as np
import pyopencl as cl
import pyopencl.array  # noqa
import loopy as lp  # noqa

import device_fixtures as device
import program_fixtures as program
import transform_fixtures as transform
import setup_fixtures as setup


def with_root_kernel(prg, root_name):
    """Return a copy of *prg* that uses *root_name* as its host-callable kernel.

    Every :class:`loopy.LoopKernel` in *prg* that is currently flagged
    ``is_called_from_host`` is re-flagged as not called from host. Then the
    kernel named *root_name* (taken from the original *prg*) is flagged as
    called from host.

    :arg prg: a loopy program that supports ``copy(name=...)``, iteration over
        callable names, ``[]`` lookup and ``with_kernel``.
    :arg root_name: name of the kernel in *prg* to promote to the root.
    :returns: the re-rooted program.
    """
    # FIXME This is a little less beautiful than it could be
    rerooted = prg.copy(name=root_name)

    for callable_name in prg:
        candidate = rerooted[callable_name]
        # Only loopy kernels carry the is_called_from_host flag we toggle.
        if not isinstance(candidate, lp.LoopKernel):
            continue
        if not candidate.is_called_from_host:
            continue
        demoted = candidate.copy(is_called_from_host=False)
        rerooted = rerooted.with_kernel(demoted)

    promoted_root = prg[root_name].copy(is_called_from_host=True)
    return rerooted.with_kernel(promoted_root)


def mult_mat_vec(queue, prg, alpha, a, b):
    """Run the ``mult_mat_vec`` kernel of *prg* and return its output on host.

    An output array shaped like *b* is allocated through
    ``setup.empty_array_on_device``. *prg* is re-rooted at ``mult_mat_vec``
    and invoked with ``a``, ``b``, ``alpha`` and that output array as ``c``.

    :arg queue: the command queue used for allocation and execution.
    :arg prg: loopy program that contains a ``mult_mat_vec`` kernel.
    :arg alpha: scalar argument forwarded to the kernel.
    :arg a: kernel argument ``a``.
    :arg b: kernel argument ``b``; its shape also sets the output shape.
    :returns: the result of ``.get()`` on the output array ``c``.
    """
    result_dev = setup.empty_array_on_device(queue, b.shape)

    rooted_prg = with_root_kernel(prg, "mult_mat_vec")
    rooted_prg(queue, a=a, b=b, c=result_dev, alpha=alpha)

    return result_dev.get()


def compute_flux_derivatives(queue, prg, params, arrays):
    """Invoke *prg* to compute flux derivatives and return them on host.

    The output array is allocated through ``setup.empty_array_on_device`` with
    shape ``(nvars, ndim, nx_halo, ny_halo, nz_halo)``, read from *params*.

    :arg queue: the command queue used for allocation and execution.
    :arg prg: callable loopy program that computes the flux derivatives.
    :arg params: object exposing ``nvars``, ``ndim``, ``nx_halo``,
        ``ny_halo`` and ``nz_halo``.
    :arg arrays: object exposing ``states``, ``fluxes``, ``metrics`` and
        ``metric_jacobians``, which are forwarded to the kernel.
    :returns: the result of ``.get()`` on the output array.
    """
    output_shape = (
        params.nvars,
        params.ndim,
        params.nx_halo,
        params.ny_halo,
        params.nz_halo,
    )
    derivatives_dev = setup.empty_array_on_device(queue, output_shape)

    prg(
        queue,
        nvars=params.nvars,
        ndim=params.ndim,
        states=arrays.states,
        fluxes=arrays.fluxes,
        metrics=arrays.metrics,
        metric_jacobians=arrays.metric_jacobians,
        flux_derivatives=derivatives_dev,
    )

    return derivatives_dev.get()

def compute_flux_derivatives_gpu(ctx_factory, params, arrays, *,
                                 dump_code_path="gen-code.cl"):
    """Run the GPU-transformed WENO kernel and return the flux derivatives.

    This is the GPU counterpart of ``compute_flux_derivatives``. It obtains its
    own queue from *ctx_factory*, retargets the program from
    ``transform.get_gpu_transformed_weno()`` to that queue's device with
    ``lp.PyOpenCLTarget``, and enables the ``no_numpy`` option. It then runs
    the program into a freshly allocated float32, Fortran-ordered output array.

    :arg ctx_factory: passed to ``device.get_queue`` to obtain the queue.
    :arg params: object exposing ``nvars``, ``ndim``, ``nx_halo``,
        ``ny_halo`` and ``nz_halo``.
    :arg arrays: object exposing ``states``, ``fluxes``, ``metrics`` and
        ``metric_jacobians``, which are forwarded to the kernel.
    :arg dump_code_path: if not ``None``, the generated device code is written
        to this file. The default keeps the previous unconditional dump to
        ``"gen-code.cl"``.
    :returns: the host copy (``.get()``) of the output array, with shape
        ``(nvars, ndim, nx_halo, ny_halo, nz_halo)``.
    """
    prg = transform.get_gpu_transformed_weno()

    queue = device.get_queue(ctx_factory)

    flux_derivatives_dev = cl.array.empty(
        queue,
        (params.nvars, params.ndim,
         params.nx_halo, params.ny_halo, params.nz_halo),
        dtype=np.float32, order="F")

    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))

    # Debug aid: keep a copy of the generated OpenCL source for inspection.
    # This was previously an unconditional "if 1:" toggle.
    if dump_code_path is not None:
        with open(dump_code_path, "w") as outf:
            outf.write(lp.generate_code_v2(prg).device_code())

    prg = lp.set_options(prg, no_numpy=True)

    prg(queue, nvars=params.nvars, ndim=params.ndim,
            states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics,
            metric_jacobians=arrays.metric_jacobians,
            flux_derivatives=flux_derivatives_dev)

    # Previously the result was computed and then discarded (implicit None).
    # Return it to match compute_flux_derivatives.
    return flux_derivatives_dev.get()
