From 753dfa79c2acdfee52c3177093078f3b3c0819cb Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 26 Mar 2015 02:08:51 -0500
Subject: [PATCH] Unify single-index subscripts with and without tuple wrapping

---
 pymbolic/mapper/unifier.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py
index b9fa9d7..55f7a73 100644
--- a/pymbolic/mapper/unifier.py
+++ b/pymbolic/mapper/unifier.py
@@ -30,8 +30,6 @@ from pymbolic.mapper import RecursiveMapper
 from pymbolic.primitives import Variable
 
 
-
-
 def unify_map(map1, map2):
     result = map1.copy()
     for name, value in six.iteritems(map2):
@@ -44,8 +42,6 @@ def unify_map(map1, map2):
     return result
 
 
-
-
 class UnificationRecord(object):
     def __init__(self, equations, lmap=None, rmap=None):
         self.equations = equations
@@ -85,8 +81,6 @@ class UnificationRecord(object):
                 for lhs, rhs in self.equations))
 
 
-
-
 def unify_many(unis1, uni2):
     result = []
     for uni1 in unis1:
@@ -97,8 +91,6 @@ def unify_many(unis1, uni2):
     return result
 
 
-
-
 class UnifierBase(RecursiveMapper):
     # The idea of the algorithm here is that the unifier accumulates a set of
     # unification possibilities (:class:`UnificationRecord`) as it descends the
@@ -174,8 +166,20 @@ class UnifierBase(RecursiveMapper):
         if not isinstance(other, type(expr)):
             return self.treat_mismatch(expr, other, urecs)
 
+        # {{{ unpack length-1 index tuples to avoid ambiguity
+
+        expr_index = expr.index
+        if isinstance(expr_index, tuple) and len(expr_index) == 1:
+            expr_index, = expr_index
+
+        other_index = other.index
+        if isinstance(other_index, tuple) and len(other_index) == 1:
+            other_index, = other_index
+
+        # }}}
+
         return self.rec(expr.aggregate, other.aggregate,
-                self.rec(expr.index, other.index, urecs))
+                self.rec(expr_index, other_index, urecs))
 
     def map_lookup(self, expr, other, urecs):
         if not isinstance(other, type(expr)):
@@ -252,8 +256,6 @@ class UnifierBase(RecursiveMapper):
         return self.rec(expr, other, urecs)
 
 
-
-
 class UnidirectionalUnifier(UnifierBase):
     """Finds assignments of variables encountered in the
     first ("left") expression to subexpression of the second
-- 
GitLab