import numpy as np
import pyopencl as cl
import loopy as lp  # noqa

import fixtures


def mult_mat_vec(ctx_factory, alpha, a, b):
    """Run the ``mult_mat_vec`` kernel of the WENO program and return *c*.

    :arg ctx_factory: handed to :func:`fixtures.get_queue` to obtain the
        command queue.
    :arg alpha: passed to the kernel as ``alpha``.
    :arg a: passed to the kernel as ``a``.
    :arg b: passed to the kernel as ``b``.
    :returns: the kernel's ``c`` output copied back to the host, a
        length-10 ``float32`` array.
    """
    cl_queue = fixtures.get_queue(ctx_factory)

    # Output buffer on the device; its length is fixed at 10 here.
    out_dev = cl.array.empty(cl_queue, 10, dtype=np.float32)

    weno = fixtures.get_weno_program()
    knl_prg = fixtures.with_root_kernel(weno, "mult_mat_vec")
    knl_prg(cl_queue, a=a, b=b, c=out_dev, alpha=alpha)

    host_result = out_dev.get()
    return host_result


def compute_flux_derivatives(ctx_factory, params, arrays):
    """Run the untransformed ``compute_flux_derivatives`` kernel.

    The kernel is taken from the WENO program, annotated with the assumption
    ``nx > 0 and ny > 0 and nz > 0``, and three of its temporaries are moved
    to the global address space before it is invoked.

    :arg ctx_factory: handed to :func:`fixtures.get_queue` to obtain the
        command queue.
    :arg params: object providing ``nvars``, ``ndim``, ``nx``, ``ny``,
        ``nz`` and ``nhalo``.
    :arg arrays: object providing ``states``, ``fluxes``, ``metrics`` and
        ``metric_jacobians``, which are passed to the kernel unchanged.
    :returns: the ``flux_derivatives`` output copied back to the host.
        The device buffer has shape
        ``(nvars, ndim, nx + nhalo, ny + nhalo, nz + nhalo)``, dtype
        ``float32``, and Fortran order.
    """
    cl_queue = fixtures.get_queue(ctx_factory)

    weno = fixtures.get_weno_program()
    knl = weno["compute_flux_derivatives"]
    knl = lp.assume(knl, "nx > 0 and ny > 0 and nz > 0")

    # Place these temporaries in global memory.  NOTE(review): presumably
    # because they are too large for private/local storage -- confirm.
    global_temporaries = (
        "flux_derivatives_generalized",
        "generalized_fluxes",
        "weno_flux_tmp",
    )
    for tmp_name in global_temporaries:
        knl = lp.set_temporary_scope(knl, tmp_name, lp.AddressSpace.GLOBAL)

    weno = weno.with_kernel(knl)

    extent_x = params.nx + params.nhalo
    extent_y = params.ny + params.nhalo
    extent_z = params.nz + params.nhalo
    out_shape = (params.nvars, params.ndim, extent_x, extent_y, extent_z)

    out_dev = cl.array.empty(cl_queue, out_shape, dtype=np.float32, order="F")

    weno(
        cl_queue,
        nvars=params.nvars,
        ndim=params.ndim,
        states=arrays.states,
        fluxes=arrays.fluxes,
        metrics=arrays.metrics,
        metric_jacobians=arrays.metric_jacobians,
        flux_derivatives=out_dev,
    )
    return out_dev.get()

def compute_flux_derivatives_gpu(ctx_factory, params, arrays, *,
        code_dump_path="gen-code.cl"):
    """Run the GPU-transformed ``compute_flux_derivatives`` program.

    The program comes from :func:`fixtures.get_gpu_transformed_weno`. It is
    retargeted to the device of the queue obtained from *ctx_factory* and
    then invoked on the arrays held by *arrays*.

    :arg ctx_factory: handed to :func:`fixtures.get_queue` to obtain the
        command queue.
    :arg params: object providing ``nvars``, ``ndim``, ``nx``, ``ny``,
        ``nz`` and ``nhalo``.
    :arg arrays: object providing ``states``, ``fluxes``, ``metrics`` and
        ``metric_jacobians``, which are passed to the kernel unchanged.
    :arg code_dump_path: if not *None*, the generated device code is written
        to this path before the kernel runs. The default ``"gen-code.cl"``
        keeps the dump on; pass *None* to skip writing the file.
    :returns: the ``flux_derivatives`` output copied back to the host, so
        the result can be compared with :func:`compute_flux_derivatives`.
        The device buffer has shape
        ``(nvars, ndim, nx + nhalo, ny + nhalo, nz + nhalo)``, dtype
        ``float32``, and Fortran order.
    """
    prg = fixtures.get_gpu_transformed_weno()

    queue = fixtures.get_queue(ctx_factory)

    nx_halo = params.nx + params.nhalo
    ny_halo = params.ny + params.nhalo
    nz_halo = params.nz + params.nhalo

    flux_derivatives_dev = cl.array.empty(queue, (params.nvars, params.ndim,
        nx_halo, ny_halo, nz_halo), dtype=np.float32, order="F")

    # Generate code for the specific device the queue belongs to.
    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))

    # Debugging aid: dump the generated device code for inspection.
    if code_dump_path is not None:
        with open(code_dump_path, "w", encoding="utf-8") as outf:
            outf.write(lp.generate_code_v2(prg).device_code())

    # NOTE(review): no_numpy disables loopy's host-numpy-array handling in
    # the invoker; this presumably relies on *arrays* holding device arrays
    # -- confirm against the fixtures.
    prg = lp.set_options(prg, no_numpy=True)

    prg(queue, nvars=params.nvars, ndim=params.ndim,
            states=arrays.states, fluxes=arrays.fluxes, metrics=arrays.metrics,
            metric_jacobians=arrays.metric_jacobians,
            flux_derivatives=flux_derivatives_dev)

    # Previously the computed result was silently discarded.
    return flux_derivatives_dev.get()
