# Copyright (C) 2019 Andreas Kloeckner, Timothy A. Smith
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import logging
import os
import sys

import numpy as np  # noqa: F401
import numpy.linalg as la  # noqa: F401
import pyopencl as cl
import pyopencl.array  # noqa
import pyopencl.tools  # noqa
import pyopencl.clrandom  # noqa
import loopy as lp  # noqa



# Cache of parsed loopy programs; get_weno_program stores its result under
# the key "default".
_WENO_PRG: dict = {}

# The directory one level above the directory that contains this file.
project_root = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))


def get_queue(ctx_factory):
    """Build a new context with *ctx_factory* and return a command queue on it.

    :arg ctx_factory: zero-argument callable that returns a context.
    """
    context = ctx_factory()
    return cl.CommandQueue(context)


def get_weno_program():
    """Return the loopy program parsed from ``WENO.F90``, parsing it at most once.

    The Fortran source is read from ``WENO.F90`` in ``project_root`` and
    converted with :func:`loopy.parse_transformed_fortran`. The result is
    cached in the module-level ``_WENO_PRG`` dict under the key ``"default"``,
    so later calls return the same program object without re-reading the file.
    """
    # Check for the specific cache key rather than for "dict is non-empty",
    # so unrelated entries in the cache cannot cause a KeyError here.
    cached = _WENO_PRG.get("default")
    if cached is not None:
        return cached

    fn = os.path.join(project_root, "WENO.F90")

    # Use an explicit encoding so the source is read the same way regardless
    # of the platform's locale default.
    with open(fn, "r", encoding="utf-8") as infile:
        infile_content = infile.read()

    prg = lp.parse_transformed_fortran(infile_content, filename=fn)
    _WENO_PRG["default"] = prg
    return prg


def transform_weno_for_gpu(prg, print_kernel=False):
    """Apply GPU-oriented loopy transformations to the WENO program.

    In the ``compute_flux_derivatives`` kernel, each ``i*`` loop is split by
    16 and mapped to the ``g.0``/``l.0`` axes, and each ``j*`` loop is split
    by 16 and mapped to ``g.1``/``l.1``. Barriers are then inserted between
    each consecutive pair of tagged stages: flux_x_compute, flux_x_diff,
    flux_y_compute, flux_y_diff, flux_z_compute, flux_z_diff.

    :arg prg: loopy program containing a ``compute_flux_derivatives`` kernel.
    :arg print_kernel: if true, print the ``convert_to_generalized_frozen``
        kernel of the transformed program, then stop by raising
        :exc:`RuntimeError`. This is a debugging aid.
    :returns: the transformed program.
    :raises RuntimeError: if *print_kernel* is true.
    """
    knl = prg["compute_flux_derivatives"]

    # The suffixed inames are presumably loopy's renamings of the repeated
    # i/j loops in WENO.F90 -- TODO confirm against the parsed kernel.
    for suffix in ["", "_1", "_2", "_3", "_4", "_5"]:
        knl = lp.split_iname(knl, "i"+suffix, 16,
                outer_tag="g.0", inner_tag="l.0")
        knl = lp.split_iname(knl, "j"+suffix, 16,
                outer_tag="g.1", inner_tag="l.1")

    # A barrier separates every consecutive pair of stages, presumably
    # because each stage consumes results written by the previous one.
    stage_tags = [
            "tag:flux_x_compute", "tag:flux_x_diff",
            "tag:flux_y_compute", "tag:flux_y_diff",
            "tag:flux_z_compute", "tag:flux_z_diff",
            ]
    for before, after in zip(stage_tags, stage_tags[1:]):
        knl = lp.add_barrier(knl, before, after)

    prg = prg.with_kernel(knl)

    if print_kernel:
        # NOTE(review): this prints convert_to_generalized_frozen, not the
        # compute_flux_derivatives kernel transformed above -- confirm intended.
        print(prg["convert_to_generalized_frozen"])
        # Stop deliberately so the printed kernel can be inspected.
        raise RuntimeError(
                "transform_weno_for_gpu: stopping after print_kernel=True")

    return prg


def empty_array_on_device(queue, *shape, dtype=np.float64):
    """Allocate an uninitialized, Fortran-ordered array on the device.

    :arg queue: command queue to allocate on.
    :arg shape: array extents, passed as separate positional arguments.
    :arg dtype: element type (keyword-only); defaults to ``numpy.float64``.
    :returns: a :class:`pyopencl.array.Array` with ``order="F"``.
    """
    return cl.array.empty(queue, shape, dtype=dtype, order="F")


def random_array_on_device(queue, *shape):
    """Allocate a device array of the given extents and fill it with ``fill_rand``.

    The array comes from ``empty_array_on_device(queue, *shape)`` and is
    filled in place by :func:`pyopencl.clrandom.fill_rand`.
    """
    result = empty_array_on_device(queue, *shape)
    cl.clrandom.fill_rand(result)
    return result


def write_target_device_code(prg, outfilename="gen-code.cl"):
    """Write the generated device-side source code of *prg* to *outfilename*.

    :arg prg: loopy program to generate code for.
    :arg outfilename: path of the output file; it is overwritten.
    """
    # Generate before opening the file, so that a code-generation failure
    # does not leave behind a truncated or empty output file.
    device_code = lp.generate_code_v2(prg).device_code()

    with open(outfilename, "w") as outf:
        outf.write(device_code)


def benchmark_compute_flux_derivatives_gpu(ctx_factory, write_code=False,
        n=16*16, nrounds=10, nwarmup_rounds=2, trace_enqueues=False):
    """Time the GPU-transformed WENO program on random input data.

    Parses the WENO program, applies :func:`transform_weno_for_gpu`, retargets
    it to the device of a queue built from *ctx_factory*, fills the inputs
    with random data, and prints throughput averaged over *nrounds* runs.

    :arg ctx_factory: zero-argument callable returning a context.
    :arg write_code: if true, write the generated device code to
        ``gen-code.cl`` via :func:`write_target_device_code`.
    :arg n: grid extent in each of the three spatial dimensions. Arrays are
        allocated with 6 extra points per dimension.
    :arg nrounds: number of timed runs to average over; must be >= 1.
    :arg nwarmup_rounds: number of untimed runs made before timing starts.
    :arg trace_enqueues: if true, temporarily wrap
        ``pyopencl.enqueue_nd_range_kernel`` to print each enqueued kernel's
        name. The original function is restored afterwards.
    :raises ValueError: if *nrounds* is less than 1.
    """
    from functools import partial
    from time import perf_counter

    if nrounds < 1:
        raise ValueError(f"nrounds must be at least 1, got {nrounds}")

    logging.basicConfig(level="INFO")

    prg = get_weno_program()
    prg = transform_weno_for_gpu(prg)

    queue = get_queue(ctx_factory)
    prg = prg.copy(target=lp.PyOpenCLTarget(queue.device))
    prg = lp.set_options(prg, no_numpy=True)
    prg = lp.set_options(prg, ignore_boostable_into=True)
    # Other debugging aids that can be enabled here:
    #   prg = lp.set_options(prg, write_wrapper=True)
    #   print(lp.get_op_map(prg, count_redundant_work=False))

    ndim = 3
    nvars = 5
    nx = ny = nz = n
    # Six extra points per dimension, presumably 3 ghost layers per side for
    # the WENO stencil -- TODO confirm against WENO.F90.
    padded = (nx+6, ny+6, nz+6)

    print("ARRAY GEN")
    states = random_array_on_device(queue, nvars, *padded)
    fluxes = random_array_on_device(queue, nvars, ndim, *padded)
    metrics = random_array_on_device(queue, ndim, ndim, *padded)
    metric_jacobians = random_array_on_device(queue, *padded)
    print("END ARRAY GEN")

    flux_derivatives_dev = empty_array_on_device(queue, nvars, ndim, *padded)

    if write_code:
        write_target_device_code(prg)

    allocator = pyopencl.tools.MemoryPool(
            pyopencl.tools.ImmediateAllocator(queue))

    run = partial(prg, queue, nvars=nvars, ndim=ndim,
            states=states, fluxes=fluxes, metrics=metrics,
            metric_jacobians=metric_jacobians,
            flux_derivatives=flux_derivatives_dev,
            allocator=allocator)

    old_enqueue_nd_range_kernel = cl.enqueue_nd_range_kernel
    if trace_enqueues:
        def enqueue_nd_range_kernel_wrapper(queue, ker, *args, **kwargs):
            print(f"Enqueueing {ker.function_name}")
            return old_enqueue_nd_range_kernel(queue, ker, *args, **kwargs)

        cl.enqueue_nd_range_kernel = enqueue_nd_range_kernel_wrapper

    try:
        print("warmup")
        for _ in range(nwarmup_rounds):
            run()

        # Drain outstanding work so the warmup is not counted in the timing.
        queue.finish()
        print("timing")
        start = perf_counter()

        for _ in range(nrounds):
            run()

        queue.finish()
        one_round = (perf_counter() - start)/nrounds
    finally:
        # Undo the tracing monkeypatch, if any, so later code is unaffected.
        cl.enqueue_nd_range_kernel = old_enqueue_nd_range_kernel

    print(f"M RHSs/s: {ndim*nvars*nx*ny*nz/one_round/1e6}")
    print(f"elapsed per round: {one_round} s")
    print(f"Output size: {flux_derivatives_dev.nbytes/1e6} MB")
    print(f"Output size: {flux_derivatives_dev.nbytes/1e6} MB")


if __name__ == "__main__":
    if len(sys.argv) <= 1:
        # No argument given: run the default benchmark with pyopencl's
        # interactive context chooser.
        benchmark_compute_flux_derivatives_gpu(cl._csc)
    else:
        # The first command-line argument is executed as Python source in
        # module scope, e.g. "benchmark_compute_flux_derivatives_gpu(cl._csc)".
        # NOTE(review): exec of argv is acceptable only for local, trusted use.
        exec(sys.argv[1])
