From 3b7bdf6c7865ecf5c933cbcbc47e94fdc958a3b7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 3 Nov 2011 03:35:39 -0400
Subject: [PATCH] Unifier: Add force_var_match.

---
 pymbolic/mapper/unifier.py | 80 ++++++++++++++++++++------------------
 1 file changed, 42 insertions(+), 38 deletions(-)

diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py
index 0ebe1c8..e13e563 100644
--- a/pymbolic/mapper/unifier.py
+++ b/pymbolic/mapper/unifier.py
@@ -73,21 +73,29 @@ def unify_many(unis1, uni2):
 
 class UnifierBase(RecursiveMapper):
     def __init__(self, lhs_mapping_candidates=None,
-            rhs_mapping_candidates=None):
+            rhs_mapping_candidates=None,
+            force_var_match=True):
         self.lhs_mapping_candidates = lhs_mapping_candidates
         self.rhs_mapping_candidates = rhs_mapping_candidates
+        self.force_var_match = force_var_match
 
     def unification_record_from_equation(self, lhs, rhs):
         if isinstance(lhs, (tuple, list)) or isinstance(rhs, (tuple, list)):
             # must match elementwise!
             return None
 
+        lhs_is_var = isinstance(lhs, Variable)
+        rhs_is_var = isinstance(rhs, Variable)
+
+        if self.force_var_match and not (lhs_is_var or rhs_is_var):
+            return None
+
         if (self.lhs_mapping_candidates is not None
-                and isinstance(lhs, Variable)
+                and lhs_is_var
                 and lhs.name not in self.lhs_mapping_candidates):
             return None
         if (self.rhs_mapping_candidates is not None
-                and isinstance(rhs, Variable)
+                and rhs_is_var
                 and rhs.name not in self.rhs_mapping_candidates):
             return None
 
@@ -122,17 +130,46 @@ class UnifierBase(RecursiveMapper):
     def map_lookup(self, expr, other, unis):
         if not isinstance(other, type(expr)):
             return self.treat_mismatch(expr, other, unis)
-        if self.name != other.name:
+        if expr.name != other.name:
             return []
 
         return self.rec(expr.aggregate, other.aggregate, unis)
 
+    def map_sum(self, expr, other, unis):
+        if (not isinstance(other, type(expr))
+                or len(expr.children) != len(other.children)):
+            return []
+
+        result = []
+
+        from pytools import generate_permutations
+        had_structural_match = False
+        for perm in generate_permutations(range(len(expr.children))):
+            it_assignments = unis
+
+            for my_child, other_child in zip(
+                    expr.children,
+                    (other.children[i] for i in perm)):
+                it_assignments = self.rec(my_child, other_child, it_assignments)
+                if not it_assignments:
+                    break
+
+            if it_assignments:
+                had_structural_match = True
+                result.extend(it_assignments)
+
+        if not had_structural_match:
+            return self.treat_mismatch(expr, other, unis)
+
+        return result
+
+    map_product = map_sum
+
     def map_negation(self, expr, other, unis):
         if not isinstance(other, type(expr)):
             return self.treat_mismatch(expr, other, unis)
         return self.rec(expr.child, other.child, unis)
 
-
     def map_quotient(self, expr, other, unis):
         if not isinstance(other, type(expr)):
             return self.treat_mismatch(expr, other, unis)
@@ -162,7 +199,6 @@ class UnifierBase(RecursiveMapper):
 
         return unis
 
-
     map_tuple = map_list
 
     def __call__(self, expr, other, unis=None):
@@ -178,36 +214,6 @@ class UnidirectionalUnifier(UnifierBase):
     subexpression of the second.
     """
 
-    def map_sum(self, expr, other, unis):
-        if (not isinstance(other, type(expr))
-                or len(expr.children) != len(other.children)):
-            return []
-
-        result = []
-
-        from pytools import generate_permutations
-        had_structural_match = False
-        for perm in generate_permutations(range(len(expr.children))):
-            it_assignments = unis
-
-            for my_child, other_child in zip(
-                    expr.children,
-                    (other.children[i] for i in perm)):
-                it_assignments = self.rec(my_child, other_child, it_assignments)
-                if not it_assignments:
-                    break
-
-            if it_assignments:
-                had_structural_match = True
-                result.extend(it_assignments)
-
-        if not had_structural_match:
-            return self.treat_mismatch(expr, other, unis)
-
-        return result
-
-    map_product = map_sum
-
     def treat_mismatch(self, expr, other, unis):
         return []
 
@@ -219,5 +225,3 @@ class BidirectionalUnifier(UnifierBase):
     """
 
     treat_mismatch = UnifierBase.map_variable
-    map_sum = UnifierBase.map_variable
-    map_product = UnifierBase.map_variable
-- 
GitLab