From 3b7bdf6c7865ecf5c933cbcbc47e94fdc958a3b7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 3 Nov 2011 03:35:39 -0400 Subject: [PATCH] Unifier: Add force_var_match. --- pymbolic/mapper/unifier.py | 80 ++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index 0ebe1c8..e13e563 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -73,21 +73,29 @@ def unify_many(unis1, uni2): class UnifierBase(RecursiveMapper): def __init__(self, lhs_mapping_candidates=None, - rhs_mapping_candidates=None): + rhs_mapping_candidates=None, + force_var_match=True): self.lhs_mapping_candidates = lhs_mapping_candidates self.rhs_mapping_candidates = rhs_mapping_candidates + self.force_var_match = force_var_match def unification_record_from_equation(self, lhs, rhs): if isinstance(lhs, (tuple, list)) or isinstance(rhs, (tuple, list)): # must match elementwise! return None + lhs_is_var = isinstance(lhs, Variable) + rhs_is_var = isinstance(rhs, Variable) + + if self.force_var_match and not (lhs_is_var or rhs_is_var): + return None + if (self.lhs_mapping_candidates is not None - and isinstance(lhs, Variable) + and lhs_is_var and lhs.name not in self.lhs_mapping_candidates): return None if (self.rhs_mapping_candidates is not None - and isinstance(rhs, Variable) + and rhs_is_var and rhs.name not in self.rhs_mapping_candidates): return None @@ -122,17 +130,46 @@ class UnifierBase(RecursiveMapper): def map_lookup(self, expr, other, unis): if not isinstance(other, type(expr)): return self.treat_mismatch(expr, other, unis) - if self.name != other.name: + if expr.name != other.name: return [] return self.rec(expr.aggregate, other.aggregate, unis) + def map_sum(self, expr, other, unis): + if (not isinstance(other, type(expr)) + or len(expr.children) != len(other.children)): + return [] + + result = [] + + from pytools import generate_permutations + had_structural_match = False + for perm in generate_permutations(range(len(expr.children))): + it_assignments = unis + + for my_child, other_child in zip( + expr.children, + (other.children[i] for i in perm)): + it_assignments = self.rec(my_child, other_child, it_assignments) + if not it_assignments: + break + + if it_assignments: + had_structural_match = True + result.extend(it_assignments) + + if not had_structural_match: + return self.treat_mismatch(expr, other, unis) + + return result + + map_product = map_sum + def map_negation(self, expr, other, unis): if not isinstance(other, type(expr)): return self.treat_mismatch(expr, other, unis) return self.rec(expr.child, other.child, unis) - def map_quotient(self, expr, other, unis): if not isinstance(other, type(expr)): return self.treat_mismatch(expr, other, unis) @@ -162,7 +199,6 @@ class UnifierBase(RecursiveMapper): return unis - map_tuple = map_list def __call__(self, expr, other, unis=None): @@ -178,36 +214,6 @@ class UnidirectionalUnifier(UnifierBase): subexpression of the second. """ - def map_sum(self, expr, other, unis): - if (not isinstance(other, type(expr)) - or len(expr.children) != len(other.children)): - return [] - - result = [] - - from pytools import generate_permutations - had_structural_match = False - for perm in generate_permutations(range(len(expr.children))): - it_assignments = unis - - for my_child, other_child in zip( - expr.children, - (other.children[i] for i in perm)): - it_assignments = self.rec(my_child, other_child, it_assignments) - if not it_assignments: - break - - if it_assignments: - had_structural_match = True - result.extend(it_assignments) - - if not had_structural_match: - return self.treat_mismatch(expr, other, unis) - - return result - - map_product = map_sum - def treat_mismatch(self, expr, other, unis): return [] @@ -219,5 +225,3 @@ class BidirectionalUnifier(UnifierBase): """ treat_mismatch = UnifierBase.map_variable - map_sum = UnifierBase.map_variable - map_product = UnifierBase.map_variable -- GitLab