from __future__ import division, absolute_import

__copyright__ = "Copyright (C) 2012 Andreas Kloeckner"

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""


import six

import islpy as isl
from islpy import dim_type

from loopy.diagnostic import LoopyError


def _find_fusable_loop_domain_index(domain, other_domains):
    my_inames = set(domain.get_var_dict(dim_type.set))

    overlap_domains = []
    for i, o_dom in enumerate(other_domains):
        o_inames = set(o_dom.get_var_dict(dim_type.set))
        if my_inames & o_inames:
            overlap_domains.append(i)

    if len(overlap_domains) >= 2:
        raise LoopyError("more than one domain in one kernel has "
                "overlapping inames with a "
                "domain of the other kernel, cannot fuse: '%s'"
                % domain)

    if len(overlap_domains) == 1:
        return overlap_domains[0]
    else:
        return None


# {{{ generic merge helpers

def _ordered_merge_lists(list_a, list_b):
    result = list_a[:]
    for item in list_b:
        if item not in list_b:
            result.append(item)

    return result


def _merge_dicts(item_name, dict_a, dict_b):
    result = dict_a.copy()

    for k, v in six.iteritems(dict_b):
        if k in result:
            if v != result[k]:
                raise LoopyError("inconsistent %ss for key '%s' in merge: %s and %s"
                        % (item_name, k, v, result[k]))

        else:
            result[k] = v

    return result


def _merge_values(item_name, val_a, val_b):
    if val_a != val_b:
        raise LoopyError("inconsistent %ss in merge: %s and %s"
                % (item_name, val_a, val_b))

    return val_a

# }}}


# {{{ two-kernel fusion

def _fuse_two_kernels(knla, knlb):
    from loopy.kernel import kernel_state
    if knla.state != kernel_state.INITIAL or knlb.state != kernel_state.INITIAL:
        raise LoopyError("can only fuse kernels in INITIAL state")

    # {{{ fuse instructions

    new_instructions = knla.instructions[:]
    from pytools import UniqueNameGenerator
    insn_id_gen = UniqueNameGenerator(
            set([insna.id for insna in new_instructions]))
    for insnb in knlb.instructions:
        new_instructions.append(
                insnb.copy(id=insn_id_gen(insnb.id)))

    # }}}

    # {{{ fuse domains

    new_domains = knla.domains[:]

    for dom_b in knlb.domains:
        i_fuse = _find_fusable_loop_domain_index(dom_b, new_domains)
        if i_fuse is None:
            new_domains.append(dom_b)
        else:
            dom_a = new_domains[i_fuse]
            dom_a, dom_b = isl.align_two(dom_a, dom_b)

            shared_inames = list(
                    set(dom_a.get_var_dict(dim_type.set))
                    &
                    set(dom_b.get_var_dict(dim_type.set)))

            dom_a_s = dom_a.project_out_except(shared_inames, [dim_type.set])
            dom_b_s = dom_a.project_out_except(shared_inames, [dim_type.set])

            if not (dom_a_s <= dom_b_s and dom_b_s <= dom_a_s):
                raise LoopyError("kernels do not agree on domain of "
                        "inames '%s'" % (",".join(shared_inames)))

            new_domain = dom_a & dom_b

            new_domains[i_fuse] = new_domain

    # }}}

    # {{{ fuse args

    new_args = knla.args[:]
    for b_arg in knlb.args:
        if b_arg.name not in knla.arg_dict:
            new_args.append(b_arg)
        else:
            if b_arg != knla.arg_dict[b_arg.name]:
                raise LoopyError(
                        "argument '%s' has inconsistent definition between "
                        "the two kernels being merged" % b_arg.name)

    # }}}

    # {{{ fuse assumptions

    assump_a = knla.assumptions
    assump_b = knlb.assumptions
    assump_a, assump_b = isl.align_two(assump_a, assump_b)

    shared_param_names = list(
            set(dom_a.get_var_dict(dim_type.set))
            &
            set(dom_b.get_var_dict(dim_type.set)))

    assump_a_s = assump_a.project_out_except(shared_param_names, [dim_type.param])
    assump_b_s = assump_a.project_out_except(shared_param_names, [dim_type.param])

    if not (assump_a_s <= assump_b_s and assump_b_s <= assump_a_s):
        raise LoopyError("assumptions do not agree on kernels to be merged")

    new_assumptions = (assump_a & assump_b).params()

    # }}}

    from loopy.kernel import LoopKernel
    return LoopKernel(
            domains=new_domains,
            instructions=new_instructions,
            args=new_args,
            name="%s_and_%s" % (knla.name, knlb.name),
            preambles=_ordered_merge_lists(knla.preambles, knlb.preambles),
            preamble_generators=_ordered_merge_lists(
                knla.preamble_generators, knlb.preamble_generators),
            assumptions=new_assumptions,
            local_sizes=_merge_dicts(
                "local size", knla.local_sizes, knlb.local_sizes),
            temporary_variables=_merge_dicts(
                "temporary variable",
                knla.temporary_variables,
                knlb.temporary_variables),
            iname_to_tag=_merge_dicts(
                "iname-to-tag mapping",
                knla.iname_to_tag,
                knlb.iname_to_tag),
            substitutions=_merge_dicts(
                "substitution",
                knla.substitutions,
                knlb.substitutions),
            function_manglers=_ordered_merge_lists(
                knla.function_manglers,
                knlb.function_manglers),
            symbol_manglers=_ordered_merge_lists(
                knla.symbol_manglers,
                knlb.symbol_manglers),

            iname_slab_increments=_merge_dicts(
                "iname slab increment",
                knla.iname_slab_increments,
                knlb.iname_slab_increments),
            loop_priority=_ordered_merge_lists(
                knla.loop_priority,
                knlb.loop_priority),
            silenced_warnings=_ordered_merge_lists(
                knla.silenced_warnings,
                knlb.silenced_warnings),
            applied_iname_rewrites=_ordered_merge_lists(
                knla.applied_iname_rewrites,
                knlb.applied_iname_rewrites),
            index_dtype=_merge_values(
                "index dtype",
                knla.index_dtype,
                knlb.index_dtype),
            target=_merge_values(
                "target",
                knla.target,
                knlb.target),
            options=knla.options)

# }}}


def fuse_kernels(kernels):
    kernels = list(kernels)
    result = kernels.pop(0)
    while kernels:
        result = _fuse_two_kernels(result, kernels.pop(0))

    return result

# vim: foldmethod=marker
