diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index b9fa9d7fa195e3006184fe32eaf765f815f4ee80..55f7a73f0585f8c7d3547787285dc5bcbe4404b9 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -30,8 +30,6 @@ from pymbolic.mapper import RecursiveMapper from pymbolic.primitives import Variable - - def unify_map(map1, map2): result = map1.copy() for name, value in six.iteritems(map2): @@ -44,8 +42,6 @@ def unify_map(map1, map2): return result - - class UnificationRecord(object): def __init__(self, equations, lmap=None, rmap=None): self.equations = equations @@ -85,8 +81,6 @@ class UnificationRecord(object): for lhs, rhs in self.equations)) - - def unify_many(unis1, uni2): result = [] for uni1 in unis1: @@ -97,8 +91,6 @@ def unify_many(unis1, uni2): return result - - class UnifierBase(RecursiveMapper): # The idea of the algorithm here is that the unifier accumulates a set of # unification possibilities (:class:`UnificationRecord`) as it descends the @@ -174,8 +166,20 @@ class UnifierBase(RecursiveMapper): if not isinstance(other, type(expr)): return self.treat_mismatch(expr, other, urecs) + # {{{ unpack length-1 index tuples to avoid ambiguity + + expr_index = expr.index + if isinstance(expr_index, tuple) and len(expr_index) == 1: + expr_index, = expr_index + + other_index = other.index + if isinstance(other_index, tuple) and len(other_index) == 1: + other_index, = other_index + + # }}} + return self.rec(expr.aggregate, other.aggregate, - self.rec(expr.index, other.index, urecs)) + self.rec(expr_index, other_index, urecs)) def map_lookup(self, expr, other, urecs): if not isinstance(other, type(expr)): @@ -252,8 +256,6 @@ class UnifierBase(RecursiveMapper): return self.rec(expr, other, urecs) - - class UnidirectionalUnifier(UnifierBase): """Finds assignments of variables encountered in the first ("left") expression to subexpression of the second