import loopy as lp


def weno_for_gpu(prg):
    """Return *prg* with a GPU-oriented ``compute_flux_derivatives`` kernel.

    The ``compute_flux_derivatives`` kernel is looked up in *prg*,
    transformed, and stored back with ``prg.with_kernel``.  The
    transformations are applied in this order:

    1. Assume ``nx``, ``ny`` and ``nz`` are all positive.
    2. Put the temporaries ``flux_derivatives_generalized``,
       ``generalized_fluxes`` and ``weno_flux_tmp`` into
       ``lp.AddressSpace.GLOBAL``.
    3. Split every ``i*``/``j*`` iname by 16, tagging the outer/inner parts
       ``g.0``/``l.0`` (for ``i``) and ``g.1``/``l.1`` (for ``j``).
    4. Turn the assignments to ``delta_xi``, ``delta_eta`` and
       ``delta_zeta`` into substitution rules.
    5. Add a barrier between each pair of consecutive, tagged stages.

    :arg prg: a program that contains a kernel named
        ``compute_flux_derivatives``.
    :returns: the program with the transformed kernel substituted in.
    """
    cfd = prg["compute_flux_derivatives"]

    cfd = lp.assume(cfd, "nx > 0 and ny > 0 and nz > 0")

    for tmp_name in (
            "flux_derivatives_generalized",
            "generalized_fluxes",
            "weno_flux_tmp"):
        cfd = lp.set_temporary_scope(cfd, tmp_name, lp.AddressSpace.GLOBAL)

    # The kernel carries eight copies of the i/j inames ("i", "i_1", ...,
    # "i_7"); each copy gets the same 16x16 split.
    for suffix in ("", "_1", "_2", "_3", "_4", "_5", "_6", "_7"):
        cfd = lp.split_iname(cfd, "i"+suffix, 16,
                outer_tag="g.0", inner_tag="l.0")
        cfd = lp.split_iname(cfd, "j"+suffix, 16,
                outer_tag="g.1", inner_tag="l.1")

    for var_name in ("delta_xi", "delta_eta", "delta_zeta"):
        cfd = lp.assignment_to_subst(cfd, var_name)

    # Instruction tags of the stages, in the order the barriers chain them.
    # A barrier is added between each stage and the one listed after it.
    stage_tags = (
        "to_generalized",
        "flux_x_compute", "flux_x_diff",
        "flux_y_compute", "flux_y_diff",
        "flux_z_compute", "flux_z_diff",
        "from_generalized",
    )
    for before, after in zip(stage_tags, stage_tags[1:]):
        cfd = lp.add_barrier(cfd, "tag:"+before, "tag:"+after)

    prg = prg.with_kernel(cfd)

    # FIXME: These should work, but don't
    # FIXME: Undo the hand-inlining in WENO.F90
    #prg = lp.inline_callable_kernel(prg, "convert_to_generalized")
    #prg = lp.inline_callable_kernel(prg, "convert_from_generalized")

    return prg


def get_gpu_transformed_weno():
    """Fetch the WENO program from ``program_fixtures`` and GPU-transform it.

    :returns: the result of passing ``program_fixtures.get_weno()`` through
        :func:`weno_for_gpu`.
    """
    # Imported lazily, inside the function, as in the original code.
    import program_fixtures

    return weno_for_gpu(program_fixtures.get_weno())
