"""Pymbolic mappers for loopy."""

from __future__ import division

__copyright__ = "Copyright (C) 2012 Andreas Kloeckner"

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""




from pytools import memoize, memoize_method
import pytools.lex

from pymbolic.primitives import (
        Leaf, AlgebraicLeaf, Variable as VariableBase,
        CommonSubexpression)

from pymbolic.mapper import (
        CombineMapper as CombineMapperBase,
        IdentityMapper as IdentityMapperBase,
        RecursiveMapper,
        WalkMapper as WalkMapperBase,
        CallbackMapper as CallbackMapperBase,
        )
from pymbolic.mapper.substitutor import \
        SubstitutionMapper as SubstitutionMapperBase
from pymbolic.mapper.stringifier import \
        StringifyMapper as StringifyMapperBase
from pymbolic.mapper.dependency import \
        DependencyMapper as DependencyMapperBase
from pymbolic.mapper.unifier import UnidirectionalUnifier \
        as UnidirectionalUnifierBase

from pymbolic.parser import Parser as ParserBase

import islpy as isl
from islpy import dim_type




# {{{ loopy-specific primitives

class FunctionIdentifier(Leaf):
    """Leaf node for function identifiers.

    Mappers dispatch it to ``map_loopy_function_identifier``. It has no
    constructor arguments of its own.
    """

    mapper_method = intern("map_loopy_function_identifier")

    def __getinitargs__(self):
        # Nothing to reproduce when the node is reconstructed.
        return tuple()

class TypedCSE(CommonSubexpression):
    """A :class:`CommonSubexpression` that additionally records a *dtype*.

    :arg child: the wrapped subexpression.
    :arg prefix: passed through to :class:`CommonSubexpression`.
    :arg dtype: the data type to attach; also reported by
        :meth:`get_extra_properties`.
    """

    def __init__(self, child, prefix=None, dtype=None):
        CommonSubexpression.__init__(self, child, prefix)
        self.dtype = dtype

    def __getinitargs__(self):
        # Must follow the positional order of __init__ (child, prefix, dtype).
        # The previous order (child, dtype, prefix) swapped prefix and dtype
        # whenever the node was rebuilt from these arguments.
        return (self.child, self.prefix, self.dtype)

    def get_extra_properties(self):
        return dict(dtype=self.dtype)


class TaggedVariable(VariableBase):
    """An identifier carrying a tag, written ``name$tag`` (e.g. ``matrix$one``).

    The tag (``one`` in the example) singles out one specific use of the
    identifier. Transformations can then address only the uses tagged a
    certain way, for instance prefetching just those accesses.
    """

    mapper_method = intern("map_tagged_variable")

    def __init__(self, name, tag):
        VariableBase.__init__(self, name)
        self.tag = tag

    def __getinitargs__(self):
        return (self.name, self.tag)

    def stringifier(self):
        return StringifyMapper

class Reduction(AlgebraicLeaf):
    """A reduction of *expr* over *inames* using *operation*.

    :arg operation: a reduction operation object, or a string that is
        converted by :func:`loopy.reduction.parse_reduction_op`.
    :arg inames: a tuple of iname strings. Leading ``@`` characters on them
        are stripped by :attr:`untagged_inames`.
    :arg expr: the expression being reduced.
    """

    mapper_method = intern("map_reduction")

    def __init__(self, operation, inames, expr):
        assert isinstance(inames, tuple)

        if isinstance(operation, str):
            # Accept the textual form of the operation for convenience.
            from loopy.reduction import parse_reduction_op
            operation = parse_reduction_op(operation)

        self.operation = operation
        self.inames = inames
        self.expr = expr

    def __getinitargs__(self):
        return (self.operation, self.inames, self.expr)

    def get_hash(self):
        key = (self.__class__, self.operation, self.inames, self.expr)
        return hash(key)

    def is_equal(self, other):
        if other.__class__ != self.__class__:
            return False
        return (other.operation == self.operation
                and other.inames == self.inames
                and other.expr == self.expr)

    def stringifier(self):
        return StringifyMapper

    @property
    @memoize_method
    def untagged_inames(self):
        """*inames* with any leading ``@`` characters stripped (cached)."""
        return tuple(name.lstrip("@") for name in self.inames)

    @property
    @memoize_method
    def untagged_inames_set(self):
        """:attr:`untagged_inames` as a :class:`set` (cached)."""
        return set(self.untagged_inames)

class LinearSubscript(AlgebraicLeaf):
    """Subscript of *aggregate* by a single *index* expression.

    The loopy parser produces it from ``aggregate[[index]]``.
    """

    mapper_method = intern("map_linear_subscript")

    def __init__(self, aggregate, index):
        self.aggregate, self.index = aggregate, index

    def __getinitargs__(self):
        return (self.aggregate, self.index)

# }}}

# {{{ mappers with support for loopy-specific primitives

class IdentityMapperMixin(object):
    """Identity-mapping methods for the loopy-specific primitives.

    Mixed into pymbolic identity-style mappers, which supply ``rec``.
    """

    def map_reduction(self, expr):
        # Only the reduced expression is traversed; operation and inames
        # are carried over unchanged.
        new_expr = self.rec(expr.expr)
        return Reduction(expr.operation, expr.inames, new_expr)

    def map_tagged_variable(self, expr):
        # A leaf: nothing underneath to map.
        return expr

    def map_loopy_function_identifier(self, expr):
        return expr

    map_linear_subscript = IdentityMapperBase.map_subscript

class IdentityMapper(IdentityMapperBase, IdentityMapperMixin):
    """pymbolic's IdentityMapper extended with the loopy primitives."""

class WalkMapper(WalkMapperBase):
    """pymbolic's WalkMapper extended with the loopy primitives."""

    def map_reduction(self, expr):
        # Descend into the reduced expression only if visit() approves.
        if self.visit(expr):
            self.rec(expr.expr)

    map_tagged_variable = WalkMapperBase.map_variable

    def map_loopy_function_identifier(self, expr):
        self.visit(expr)

    map_linear_subscript = WalkMapperBase.map_subscript

class CallbackMapper(CallbackMapperBase, IdentityMapper):
    """pymbolic's CallbackMapper with loopy support.

    Reductions are handed to the callback the same way constants are.
    """

    map_reduction = CallbackMapperBase.map_constant

class CombineMapper(CombineMapperBase):
    """pymbolic's CombineMapper extended with the loopy primitives."""

    map_linear_subscript = CombineMapperBase.map_subscript

    def map_reduction(self, expr):
        # A reduction contributes exactly what its body contributes.
        return self.rec(expr.expr)

class SubstitutionMapper(SubstitutionMapperBase, IdentityMapperMixin):
    """pymbolic's SubstitutionMapper extended with the loopy primitives."""

class StringifyMapper(StringifyMapperBase):
    """pymbolic's StringifyMapper extended with the loopy primitives."""

    def map_reduction(self, expr, prec):
        iname_list = ", ".join(expr.inames)
        return "reduce(%s, [%s], %s)" % (expr.operation, iname_list, expr.expr)

    def map_tagged_variable(self, expr, prec):
        return "%s$%s" % (expr.name, expr.tag)

    def map_linear_subscript(self, expr, enclosing_prec):
        from pymbolic.mapper.stringifier import PREC_CALL, PREC_NONE
        aggregate_str = self.rec(expr.aggregate, PREC_CALL)
        index_str = self.rec(expr.index, PREC_NONE)
        return self.parenthesize_if_needed(
                self.format("%s[[%s]]", aggregate_str, index_str),
                enclosing_prec, PREC_CALL)


class DependencyMapper(DependencyMapperBase):
    """pymbolic's DependencyMapper extended with the loopy primitives."""

    def map_reduction(self, expr):
        # The reduction's own (untagged) inames are removed from the
        # dependencies of its body.
        from pymbolic.primitives import Variable
        bound = set(Variable(name) for name in expr.untagged_inames)
        return self.rec(expr.expr) - bound

    def map_tagged_variable(self, expr):
        return set([expr])

    def map_loopy_function_identifier(self, expr):
        return set()

    map_linear_subscript = DependencyMapperBase.map_subscript

class UnidirectionalUnifier(UnidirectionalUnifierBase):
    """pymbolic's UnidirectionalUnifier extended with the loopy primitives."""

    def map_reduction(self, expr, other, unis):
        if not isinstance(other, type(expr)):
            return self.treat_mismatch(expr, other, unis)

        # Reductions only unify if they reduce over the same inames with
        # the same kind of operation.
        same_shape = (expr.inames == other.inames
                and type(expr.operation) == type(other.operation))
        if not same_shape:
            return []

        return self.rec(expr.expr, other.expr, unis)

    def map_tagged_variable(self, expr, other, urecs):
        new_uni_record = self.unification_record_from_equation(expr, other)
        if new_uni_record is not None:
            from pymbolic.mapper.unifier import unify_many
            return unify_many(urecs, new_uni_record)

        # No unification record--but a literally identical tagged variable
        # that is not a substitution candidate still matches.
        literal_match = (isinstance(other, TaggedVariable)
                and expr.name == other.name
                and expr.tag == other.tag
                and expr.name not in self.lhs_mapping_candidates)
        if literal_match:
            return urecs
        return []

# }}}

# {{{ functions to primitives, parsing

class VarToTaggedVarMapper(IdentityMapper):
    """Turn variables named ``name$tag`` into :class:`TaggedVariable` nodes.

    The first ``$`` separates name from tag. Variables without a ``$`` are
    returned unchanged.
    """

    def map_variable(self, expr):
        name, sep, tag = expr.name.partition("$")
        if not sep:
            return expr
        return TaggedVariable(name, tag)

class FunctionToPrimitiveMapper(IdentityMapper):
    """Looks for invocations of functions called 'cse', 'reduce' or 'if', or
    named after an existing reduction operation, and turns those into the
    actual pymbolic primitives used for that.
    """

    def map_call(self, expr):
        from pymbolic.primitives import Variable
        if not isinstance(expr.function, Variable):
            return IdentityMapper.map_call(self, expr)

        name = expr.function.name
        if name == "cse":
            from pymbolic.primitives import CommonSubexpression
            if len(expr.parameters) in [1, 2]:
                if len(expr.parameters) == 2:
                    if not isinstance(expr.parameters[1], Variable):
                        raise TypeError("second argument to cse() must be a symbol")
                    tag = expr.parameters[1].name
                else:
                    tag = None

                return CommonSubexpression(
                        self.rec(expr.parameters[0]), tag)
            else:
                raise TypeError("cse takes one or two arguments")

        elif name == "reduce":
            # reduce(operation, inames, expr)
            if len(expr.parameters) == 3:
                operation, inames, red_expr = expr.parameters
            else:
                raise TypeError("invalid 'reduce' calling sequence")

        elif name == "if":
            if len(expr.parameters) in [2, 3]:
                from pymbolic.primitives import If
                # Map the condition and branches as well, so that cse(),
                # reduce(), ... nested inside an if() are converted too.
                return If(*[self.rec(par) for par in expr.parameters])
            else:
                raise TypeError("if takes two or three arguments")

        else:
            # see if 'name' is an existing reduction op, invoked as
            # op(inames, expr)

            from loopy.reduction import parse_reduction_op
            if parse_reduction_op(name):
                if len(expr.parameters) != 2:
                    raise RuntimeError("invalid invocation of "
                            "reduction operation '%s'" % expr.function.name)

                operation = expr.function
                inames, red_expr = expr.parameters
            else:
                return IdentityMapper.map_call(self, expr)

        # Only the two reduction forms reach this point.
        red_expr = self.rec(red_expr)

        if not isinstance(operation, Variable):
            raise TypeError("operation argument to reduce() must be a symbol")
        operation = operation.name
        if isinstance(inames, Variable):
            inames = (inames,)

        if not isinstance(inames, tuple):
            raise TypeError("iname argument to reduce() must be a symbol "
                    "or a tuple of symbols")

        processed_inames = []
        for iname in inames:
            if not isinstance(iname, Variable):
                raise TypeError("iname argument to reduce() must be a symbol "
                        "or a tuple of symbols")

            processed_inames.append(iname.name)

        return Reduction(operation, tuple(processed_inames), red_expr)

# {{{ parser extension

# Lexer tags for the '[[' and ']]' tokens. Interned because parse_postfix
# below compares the next tag against _open_dbl_bracket with 'is'.
_open_dbl_bracket = intern("open_dbl_bracket")
_close_dbl_bracket = intern("close_dbl_bracket")

class LoopyParser(ParserBase):
    """pymbolic parser extended with the ``aggregate[[index]]`` syntax, which
    produces a :class:`LinearSubscript`.
    """

    # The double-bracket tokens are placed ahead of the base lexer table,
    # presumably so that '[[' is lexed as one token rather than two '['
    # tokens. NOTE(review): assumes pytools.lex tries entries in order--confirm.
    lex_table = [
            (_open_dbl_bracket, pytools.lex.RE(r"\[\[")),
            (_close_dbl_bracket, pytools.lex.RE(r"\]\]")),
            ] + ParserBase.lex_table

    def parse_postfix(self, pstate, min_precedence, left_exp):
        """Parse a ``[[index]]`` postfix applied to *left_exp*; all other
        postfix forms are handled by the base parser.
        """
        from pymbolic.parser import _PREC_CALL
        # Only bind when the enclosing precedence is below call precedence.
        if pstate.next_tag() is _open_dbl_bracket and _PREC_CALL > min_precedence:
            pstate.advance()
            pstate.expect_not_end()
            left_exp = LinearSubscript(left_exp, self.parse_expression(pstate))
            pstate.expect(_close_dbl_bracket)
            pstate.advance()
            # NOTE(review): the boolean presumably tells the caller that a
            # postfix was consumed--confirm against pymbolic.parser.
            return left_exp, True

        return ParserBase.parse_postfix(self, pstate, min_precedence, left_exp)

# }}}

def parse(expr_str):
    """Parse *expr_str* into a loopy expression.

    The steps are:

    1. Parse the text with :class:`LoopyParser`.
    2. Convert special function calls with :class:`FunctionToPrimitiveMapper`.
    3. Convert ``name$tag`` variables with :class:`VarToTaggedVarMapper`.
    """
    parsed = LoopyParser()(expr_str)
    with_primitives = FunctionToPrimitiveMapper()(parsed)
    return VarToTaggedVarMapper()(with_primitives)

# }}}

# {{{ reduction loop splitter

class ReductionLoopSplitter(IdentityMapper):
    """Make reductions over *old_iname* reduce over *outer_iname* and
    *inner_iname* instead.

    The first occurrence of *old_iname* is removed and the two new inames
    are appended. The body of a rewritten reduction is kept as is. Other
    reductions are mapped as usual.
    """

    def __init__(self, old_iname, outer_iname, inner_iname):
        self.old_iname = old_iname
        self.outer_iname = outer_iname
        self.inner_iname = inner_iname

    def map_reduction(self, expr):
        if self.old_iname not in expr.inames:
            return IdentityMapper.map_reduction(self, expr)

        split_inames = list(expr.inames)
        split_inames.remove(self.old_iname)
        split_inames += [self.outer_iname, self.inner_iname]
        return Reduction(expr.operation, tuple(split_inames), expr.expr)

# }}}

# {{{ coefficient collector

class CoefficientCollector(RecursiveMapper):
    """Collect the coefficients of an expression that is linear in its
    variables.

    Mapping an expression returns a dict from variable name to coefficient.
    The constant term is stored under the key ``1``.

    :raises RuntimeError: for products of more than one factor containing
        variables ("nonlinear expression") and for subscripts.
    """

    def map_sum(self, expr):
        stride_dicts = [self.rec(ch) for ch in expr.children]

        result = {}
        for stride_dict in stride_dicts:
            for var, stride in stride_dict.items():
                if var in result:
                    result[var] += stride
                else:
                    result[var] = stride

        return result

    def map_product(self, expr):
        children_coeffs = [self.rec(child) for child in expr.children]

        # Find the single factor that contains variables (str keys). A
        # second such factor would make the product nonlinear.
        idx_of_child_with_vars = None
        for i, child_coeffs in enumerate(children_coeffs):
            for k in child_coeffs:
                if isinstance(k, str):
                    if (idx_of_child_with_vars is not None
                            and idx_of_child_with_vars != i):
                        raise RuntimeError(
                                "nonlinear expression")
                    idx_of_child_with_vars = i

        # Every other factor is a constant, i.e. {1: value}. Note that
        # "[1]" looks up the constant-term key, not a position.
        other_coeffs = 1
        for i, child_coeffs in enumerate(children_coeffs):
            if i != idx_of_child_with_vars:
                assert len(child_coeffs) == 1
                other_coeffs *= child_coeffs[1]

        if idx_of_child_with_vars is None:
            return {1: other_coeffs}

        return dict(
                (var, other_coeffs*coeff)
                for var, coeff in
                children_coeffs[idx_of_child_with_vars].items())

    def map_constant(self, expr):
        return {1: expr}

    def map_variable(self, expr):
        return {expr.name: 1}

    map_tagged_variable = map_variable

    def map_subscript(self, expr):
        raise RuntimeError("cannot gather coefficients--indirect addressing in use")

    def map_constant(self, expr):
        return {1: expr}

    def map_variable(self, expr):
        return {expr.name: 1}

    map_tagged_variable = map_variable

    def map_subscript(self, expr):
        raise RuntimeError("cannot gather coefficients--indirect addressing in use")

# }}}

# {{{ variable index expression collector

class ArrayAccessFinder(CombineMapper):
    """Collect the subscript expressions (array accesses) in an expression.

    :arg tgt_vector_name: if given, only subscripts whose aggregate is the
        variable of that name are collected; otherwise all are.
    """

    def __init__(self, tgt_vector_name=None):
        self.tgt_vector_name = tgt_vector_name

    def combine(self, values):
        from pytools import flatten
        return set(flatten(values))

    def map_constant(self, expr):
        return set()

    def map_algebraic_leaf(self, expr):
        return set()

    def map_subscript(self, expr):
        from pymbolic.primitives import Variable
        assert isinstance(expr.aggregate, Variable)

        wanted = (self.tgt_vector_name is None
                or expr.aggregate.name == self.tgt_vector_name)
        if not wanted:
            return CombineMapper.map_subscript(self, expr)

        # Accesses nested inside the index are collected as well.
        return set([expr]) | self.rec(expr.index)

# }}}

# {{{ aff <-> expr conversion

def aff_to_expr(aff, except_name=None, error_on_name=None):
    """Convert an isl affine expression *aff* into a pymbolic expression.

    :arg except_name: if given, the coefficient of the dimension with this
        name is left out of the expression and returned separately.
    :arg error_on_name: if given, a nonzero coefficient on the dimension
        with this name raises :exc:`RuntimeError`.
    :returns: the expression with ``// denom`` applied, where *denom* is
        ``int(aff.get_denominator())``. If *except_name* is given, returns a
        tuple ``(expr // denom, except_coeff // denom)`` instead.
    :raises ValueError: if both *except_name* and *error_on_name* are given.
    :raises RuntimeError: if *error_on_name* occurs (directly or, see below,
        inside a div term), or if the coefficient of *except_name* is not
        divisible by *denom*.
    """
    if except_name is not None and error_on_name is not None:
        raise ValueError("except_name and error_on_name may not be specified "
                "at the same time")
    from pymbolic import var

    # Accumulated coefficient of the dimension named 'except_name'.
    except_coeff = 0

    result = int(aff.get_constant())
    # Input and parameter dimensions become plain pymbolic variables.
    for dt in [dim_type.in_, dim_type.param]:
        for i in xrange(aff.dim(dt)):
            coeff = int(aff.get_coefficient(dt, i))
            if coeff:
                dim_name = aff.get_dim_name(dt, i)
                if dim_name == except_name:
                    except_coeff += coeff
                elif dim_name == error_on_name:
                    raise RuntimeError("'%s' occurred in this subexpression--"
                            "this is not allowed" % dim_name)
                else:
                    result += coeff*var(dim_name)

    # Within div terms the excepted name cannot be split off, so from here
    # on its occurrence is treated as an error.
    error_on_name = error_on_name or except_name

    # Div dimensions: convert each defining aff recursively.
    for i in xrange(aff.dim(dim_type.div)):
        coeff = int(aff.get_coefficient(dim_type.div, i))
        if coeff:
            result += coeff*aff_to_expr(aff.get_div(i), error_on_name=error_on_name)

    denom = int(aff.get_denominator())
    if except_name is not None:
        if except_coeff % denom != 0:
            raise RuntimeError("coefficient of '%s' is not divisible by "
                    "aff denominator" % except_name)

        return result // denom, except_coeff // denom
    else:
        return result // denom




def pw_aff_to_expr(pw_aff, int_ok=False):
    """Convert a single-piece isl ``PwAff`` into a pymbolic expression.

    :arg pw_aff: the piecewise affine expression. A plain :class:`int` is
        also accepted and returned unchanged.
    :arg int_ok: if false, passing an :class:`int` emits a warning.
    :raises NotImplementedError: if *pw_aff* does not have exactly one piece.
    """
    if isinstance(pw_aff, int):
        if not int_ok:
            from warnings import warn
            warn("expected PwAff, got int", stacklevel=2)

        return pw_aff

    pieces = pw_aff.get_pieces()

    if len(pieces) != 1:
        raise NotImplementedError("pw_aff_to_expr for multi-piece PwAff instances")

    # Each piece is a (domain, aff) pair; only the aff is needed. The domain
    # is bound to '_domain' rather than 'set' to avoid shadowing the builtin.
    (_domain, aff), = pieces
    return aff_to_expr(aff)

def aff_from_expr(space, expr, vars_to_zero=frozenset()):
    """Convert a pymbolic expression *expr* into an isl ``Aff`` on *space*.

    Each named dimension of *space* is mapped to the aff with coefficient 1
    on that dimension; set dimensions are addressed as input dimensions of
    the aff. *expr* is then evaluated with these affs as variable values.

    :arg vars_to_zero: names of variables that evaluate to the zero aff.
        This overrides a same-named dimension of *space*. (The default is an
        immutable empty set, so no mutable default object is shared between
        calls.)
    """
    zero = isl.Aff.zero_on_domain(isl.LocalSpace.from_space(space))

    context = {}
    for name, (dt, pos) in space.get_var_dict().items():
        if dt == dim_type.set:
            dt = dim_type.in_

        context[name] = zero.set_coefficient(dt, pos, 1)

    for name in vars_to_zero:
        context[name] = zero

    from pymbolic import evaluate
    # Adding 'zero' presumably ensures a constant *expr* still yields an Aff.
    return zero + evaluate(expr, context)

# }}}

# {{{ expression <-> constraint conversion

def eq_constraint_from_expr(space, expr):
    """Build an isl equality constraint from *expr*.

    *expr* is converted to an aff on *space* via :func:`aff_from_expr`, then
    passed to ``isl.Constraint.equality_from_aff``.
    """
    aff = aff_from_expr(space, expr)
    return isl.Constraint.equality_from_aff(aff)

def ineq_constraint_from_expr(space, expr):
    """Build an isl inequality constraint from *expr*.

    *expr* is converted to an aff on *space* via :func:`aff_from_expr`, then
    passed to ``isl.Constraint.inequality_from_aff``.
    """
    aff = aff_from_expr(space, expr)
    return isl.Constraint.inequality_from_aff(aff)

def constraint_to_expr(cns, except_name=None):
    """Convert the aff of constraint *cns* into a pymbolic expression.

    *except_name* is passed through to :func:`aff_to_expr`.
    """
    # Constraints with existentially quantified (div) variables used to be
    # rejected here. Looks like this is ok after all--get_aff() performs
    # some magic. Not entirely sure though... FIXME
    aff = cns.get_aff()
    return aff_to_expr(aff, except_name=except_name)

# }}}

# {{{ Reduction callback mapper

class ReductionCallbackMapper(IdentityMapper):
    """Identity mapper that lets *callback* rewrite reductions.

    *callback* is called as ``callback(expr, rec)``. If it returns *None*,
    the reduction is mapped as usual by :class:`IdentityMapper`.
    """

    def __init__(self, callback):
        self.callback = callback

    def map_reduction(self, expr):
        replacement = self.callback(expr, self.rec)
        if replacement is not None:
            return replacement
        return IdentityMapper.map_reduction(self, expr)

# }}}

# {{{ index dependency finding

class IndexVariableFinder(CombineMapper):
    """Find the names of the variables used in subscript indices.

    :arg include_reduction_inames: whether a reduction's own inames are
        kept in the result for that reduction.
    """

    def __init__(self, include_reduction_inames):
        self.include_reduction_inames = include_reduction_inames

    def combine(self, values):
        # Union of all child results. Same result as
        # reduce(operator.or_, values, set()), but does not depend on the
        # builtin 'reduce', which no longer exists in Python 3.
        return set().union(*values)

    def map_constant(self, expr):
        return set()

    def map_algebraic_leaf(self, expr):
        return set()

    def map_subscript(self, expr):
        idx_vars = DependencyMapper()(expr.index)

        from pymbolic.primitives import Variable
        result = set()
        for idx_var in idx_vars:
            if isinstance(idx_var, Variable):
                result.add(idx_var.name)
            else:
                raise RuntimeError("index variable not understood: %s" % idx_var)
        return result

    def map_reduction(self, expr):
        result = self.rec(expr.expr)

        if not (expr.untagged_inames_set & result):
            raise RuntimeError("reduction '%s' does not depend on "
                    "reduction inames (%s)" % (expr, ",".join(expr.inames)))
        if self.include_reduction_inames:
            return result
        else:
            return result - expr.untagged_inames_set

# }}}

# {{{ substitution callback mapper

class SubstitutionCallbackMapper(IdentityMapper):
    """Identity mapper that offers selected variables and function calls to
    a callback for replacement.

    :arg names_filter: *None* to consider every name. Otherwise an iterable
        of filters, each either a string ``"name"`` / ``"name$tag"`` or a
        ``(name, tag)`` tuple (see :meth:`parse_filter`). A filter whose tag
        is *None* matches any tag.
    :arg func: called as ``func(expr, name, tag, args, rec)``. For a
        matching variable, *args* is ``()``; for a matching call, *args* is
        ``expr.parameters``. If it returns *None*, the node is mapped as
        usual by :class:`IdentityMapper`; otherwise its return value
        replaces the node.
    """
    @staticmethod
    def parse_filter(filt):
        """Normalize a name filter to a ``(name, tag)`` tuple.

        Strings are split at ``$``; a string without ``$`` yields tag *None*.
        Tuples are returned unchanged.

        :raises RuntimeError: if a string has more than one ``$`` or a tuple
            does not have exactly two entries.
        """
        if not isinstance(filt, tuple):
            components = filt.split("$")
            if len(components) == 1:
                return (components[0], None)
            elif len(components) == 2:
                return tuple(components)
            else:
                raise RuntimeError("too many components in '%s'" % filt)
        else:
            if len(filt) != 2:
                raise RuntimeError("substitution name filters "
                        "may have at most two components")

            return filt

    def __init__(self, names_filter, func):
        if names_filter is not None:
            new_names_filter = []
            for filt in names_filter:
                new_names_filter.append(self.parse_filter(filt))

            self.names_filter = new_names_filter
        else:
            self.names_filter = names_filter

        self.func = func

    def parse_name(self, expr):
        """Return ``(name, tag)`` if *expr* is a (tagged) variable accepted by
        the name filter, else *None*. Untagged variables have tag *None*.
        """
        from pymbolic.primitives import Variable
        # TaggedVariable subclasses Variable, so it must be checked first.
        if isinstance(expr, TaggedVariable):
            e_name, e_tag = expr.name, expr.tag
        elif isinstance(expr, Variable):
            e_name, e_tag = expr.name, None
        else:
            return None

        if self.names_filter is not None:
            for filt_name, filt_tag in self.names_filter:
                if e_name == filt_name:
                    if filt_tag is None or filt_tag == e_tag:
                        return e_name, e_tag
        else:
            return e_name, e_tag

        return None

    def map_variable(self, expr):
        parsed_name = self.parse_name(expr)
        if parsed_name is None:
            # Dispatch to IdentityMapper's own method for this node type.
            # self.<mapper_method> would come back here, since
            # map_tagged_variable is aliased to this method.
            return getattr(IdentityMapper, expr.mapper_method)(self, expr)

        name, tag = parsed_name

        result = self.func(expr, name, tag, (), self.rec)
        if result is None:
            return getattr(IdentityMapper, expr.mapper_method)(self, expr)
        else:
            return result

    map_tagged_variable = map_variable

    def map_call(self, expr):
        from pymbolic.primitives import Lookup
        if isinstance(expr.function, Lookup):
            raise RuntimeError("dotted name '%s' not allowed as "
                    "function identifier" % expr.function)

        parsed_name = self.parse_name(expr.function)

        if parsed_name is None:
            return IdentityMapper.map_call(self, expr)

        name, tag = parsed_name

        result = self.func(expr, name, tag, expr.parameters, self.rec)
        if result is None:
            return IdentityMapper.map_call(self, expr)
        else:
            return result
# }}}

# {{{ parametrized substitutor

class ParametrizedSubstitutor(object):
    """Expand invocations of parametrized substitution rules.

    :arg rules: a mapping from rule name to a rule object providing
        ``arguments`` (parameter names) and ``expression`` (the body).
    :arg one_level: if true, rule invocations encountered while mapping an
        expansion are left unexpanded.
    """

    def __init__(self, rules, one_level=False):
        self.rules = rules
        self.one_level = one_level

    def __call__(self, expr):
        # Current nesting depth of rule expansions. Kept in a one-element
        # list so the nested callback can update it.
        depth = [0]

        def expand_if_known(expr, name, tag, args, rec):
            if self.one_level and depth[0] > 0:
                return None

            rule = self.rules[name]
            n_args, n_params = len(args), len(rule.arguments)
            if n_params != n_args:
                raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
                        % (name, n_args, n_params, ))

            from pymbolic.mapper.substitutor import make_subst_func
            arg_map = dict(zip(rule.arguments, args))
            subst_map = SubstitutionMapper(make_subst_func(arg_map))

            depth[0] += 1
            expanded = rec(subst_map(rule.expression))
            depth[0] -= 1

            return expanded

        mapper = SubstitutionCallbackMapper(self.rules.keys(), expand_if_known)
        return mapper(expr)

# }}}

# {{{ wildcard -> unique variable mapper

class WildcardToUniqueVariableMapper(IdentityMapper):
    """Replace each wildcard with a variable whose name comes from a fresh
    call to *unique_var_name_factory*.
    """

    def __init__(self, unique_var_name_factory):
        self.unique_var_name_factory = unique_var_name_factory

    def map_wildcard(self, expr):
        from pymbolic import var
        fresh_name = self.unique_var_name_factory()
        return var(fresh_name)

# }}}

# {{{ prime-adder

class PrimeAdder(IdentityMapper):
    """Append a prime (``'``) to the names of the variables in *which_vars*.

    Tagged variables keep their tag.
    """

    def __init__(self, which_vars):
        self.which_vars = which_vars

    def map_variable(self, expr):
        if expr.name not in self.which_vars:
            return expr

        from pymbolic import var
        return var(expr.name + "'")

    def map_tagged_variable(self, expr):
        if expr.name not in self.which_vars:
            return expr

        return TaggedVariable(expr.name + "'", expr.tag)


# }}}

@memoize
def get_dependencies(expr):
    """Return a :class:`frozenset` of the names of the dependencies of *expr*.

    Uses this module's DependencyMapper with ``composite_leaves=False``.
    Results are memoized per *expr*.
    """
    from loopy.symbolic import DependencyMapper
    deps = DependencyMapper(composite_leaves=False)(expr)
    return frozenset(d.name for d in deps)



# vim: foldmethod=marker
