import numpy as np  # noqa: F401
import numpy.linalg as la  # noqa: F401
import pyopencl as cl
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa
import sys

import logging

import utilities as u


def benchmark_compute_flux_derivatives_gpu(ctx_factory, write_code=False, *,
        n=16*16, nrounds=10, nwarmup_rounds=2, trace_enqueue=False):
    """Time the GPU-transformed WENO flux-derivative program and print results.

    Builds the program with ``u.get_weno_program``, applies
    ``u.transform_weno_for_gpu``, and targets it at the queue's device. It
    then allocates random device inputs and runs *nwarmup_rounds* untimed
    rounds followed by *nrounds* timed rounds. Finally it prints the average
    time per round, a throughput figure and the output array size.

    :arg ctx_factory: passed to ``u.get_queue`` to obtain the command queue.
        Presumably a callable such as :func:`pyopencl.create_some_context`
        (that is what the ``__main__`` entry point passes); confirm against
        ``utilities.get_queue``.
    :arg write_code: if true, pass the transformed program to
        ``u.write_target_device_code`` before running it.
    :arg n: number of points along each of the three spatial axes. Every
        array is allocated with ``n+6`` points per axis. Defaults to the
        previously hard-coded ``16*16``.
    :arg nrounds: number of timed rounds that are averaged. Must be >= 1.
    :arg nwarmup_rounds: number of untimed rounds run before timing starts.
    :arg trace_enqueue: if true, temporarily replace
        ``cl.enqueue_nd_range_kernel`` with a wrapper that prints each
        kernel's name. The original function is always restored before
        returning.
    :raises ValueError: if *nrounds* is less than 1.
    """
    from functools import partial
    from time import perf_counter

    if nrounds < 1:
        raise ValueError(f"nrounds must be >= 1, got {nrounds}")

    logging.basicConfig(level="INFO")

    prg = u.get_weno_program()
    prg = u.transform_weno_for_gpu(prg)

    queue = u.get_queue(ctx_factory)
    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
    prg = lp.set_options(prg, no_numpy=True)
    prg = lp.set_options(prg, ignore_boostable_into=True)
    #prg = lp.set_options(prg, write_wrapper=True)
    #op_map = lp.get_op_map(prg, count_redundant_work=False)
    #print(op_map)

    ndim = 3
    nvars = 5
    nx = ny = nz = n

    # All arrays carry 6 extra points per spatial axis (presumably ghost
    # cells for the WENO stencil -- TODO confirm against utilities).
    nxp, nyp, nzp = nx + 6, ny + 6, nz + 6

    print("ARRAY GEN")
    states = u.random_array_on_device(queue, nvars, nxp, nyp, nzp)
    fluxes = u.random_array_on_device(queue, nvars, ndim, nxp, nyp, nzp)
    metrics = u.random_array_on_device(queue, ndim, ndim, nxp, nyp, nzp)
    metric_jacobians = u.random_array_on_device(queue, nxp, nyp, nzp)
    print("END ARRAY GEN")

    flux_derivatives_dev = u.empty_array_on_device(
            queue, nvars, ndim, nxp, nyp, nzp)

    if write_code:
        u.write_target_device_code(prg)

    allocator = pyopencl.tools.MemoryPool(pyopencl.tools.ImmediateAllocator(queue))

    run = partial(prg, queue, nvars=nvars, ndim=ndim,
            states=states, fluxes=fluxes, metrics=metrics,
            metric_jacobians=metric_jacobians,
            flux_derivatives=flux_derivatives_dev,
            allocator=allocator)

    # {{{ optionally monkeypatch enqueue_nd_range_kernel to trace

    old_enqueue_nd_range_kernel = cl.enqueue_nd_range_kernel
    if trace_enqueue:
        def enqueue_nd_range_kernel_wrapper(cq, ker, *args, **kwargs):
            print(f"Enqueueing {ker.function_name}")
            return old_enqueue_nd_range_kernel(cq, ker, *args, **kwargs)

        cl.enqueue_nd_range_kernel = enqueue_nd_range_kernel_wrapper

    # }}}

    try:
        print("warmup")
        for _ in range(nwarmup_rounds):
            run()

        # Drain warmup work so it is not charged to the timed rounds.
        queue.finish()
        print("timing")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = perf_counter()

        for _ in range(nrounds):
            run()

        # Wait for all enqueued work to complete before reading the clock.
        queue.finish()
        one_round = (perf_counter() - start)/nrounds
    finally:
        # Undo the (optional) module-level monkeypatch so it never leaks
        # into later code running in the same process.
        cl.enqueue_nd_range_kernel = old_enqueue_nd_range_kernel

    print(f"M RHSs/s: {ndim*nvars*nx*ny*nz/one_round/1e6}")
    print(f"elapsed per round: {one_round} s")
    print(f"Output size: {flux_derivatives_dev.nbytes/1e6} MB")


if __name__ == "__main__":
    if len(sys.argv) > 1:
        # Developer convenience: execute the Python statement given as the
        # first command-line argument, e.g.
        #   python <this file> \
        #       "benchmark_compute_flux_derivatives_gpu(cl.create_some_context, write_code=True)"
        # NOTE: argv is exec()'d verbatim -- only use with trusted input.
        exec(sys.argv[1])
    else:
        # Use the public API rather than pyopencl's private alias ``cl._csc``
        # (which pyopencl defines as ``create_some_context``).
        benchmark_compute_flux_derivatives_gpu(cl.create_some_context)