diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index bdc783cb93d6e9f994288086b9eea7f0c6aa48a8..0ebe1c83a018e548c7895cc47d455e8e35253d99 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -72,22 +72,26 @@ def unify_many(unis1, uni2): class UnifierBase(RecursiveMapper): - def __init__(self, mapping_candidates=None): - self.mapping_candidates = mapping_candidates + def __init__(self, lhs_mapping_candidates=None, + rhs_mapping_candidates=None): + self.lhs_mapping_candidates = lhs_mapping_candidates + self.rhs_mapping_candidates = rhs_mapping_candidates def unification_record_from_equation(self, lhs, rhs): if isinstance(lhs, (tuple, list)) or isinstance(rhs, (tuple, list)): # must match elementwise! return None - if self.mapping_candidates is None: - return UnificationRecord([(lhs, rhs)]) - else: - if isinstance(lhs, Variable) and lhs.name not in self.mapping_candidates: - return None - if isinstance(rhs, Variable) and rhs.name not in self.mapping_candidates: - return None - return UnificationRecord([(lhs, rhs)]) + if (self.lhs_mapping_candidates is not None + and isinstance(lhs, Variable) + and lhs.name not in self.lhs_mapping_candidates): + return None + if (self.rhs_mapping_candidates is not None + and isinstance(rhs, Variable) + and rhs.name not in self.rhs_mapping_candidates): + return None + + return UnificationRecord([(lhs, rhs)]) def map_constant(self, expr, other, unis): if expr == other: @@ -99,9 +103,9 @@ class UnifierBase(RecursiveMapper): new_uni_record = self.unification_record_from_equation( expr, other) if new_uni_record is None: - if (isinstance(other, Variable) + if (isinstance(other, Variable) and other.name == expr.name - and expr.name not in self.mapping_candidates): + and expr.name not in self.lhs_mapping_candidates): return unis else: return [] @@ -128,37 +132,6 @@ class UnifierBase(RecursiveMapper): return self.treat_mismatch(expr, other, unis) return self.rec(expr.child, other.child, unis) - def map_sum(self, expr, other, unis): - if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) - - if len(expr.children) != len(other.children): - return [] - - result = [] - - from pytools import generate_permutations - had_structural_match = False - for perm in generate_permutations(range(len(expr.children))): - it_assignments = unis - - for my_child, other_child in zip( - expr.children, - (other.children[i] for i in perm)): - it_assignments = self.rec(my_child, other_child, it_assignments) - if not it_assignments: - break - - if it_assignments: - had_structural_match = True - result.extend(it_assignments) - - if not had_structural_match: - return self.treat_mismatch(expr, other, unis) - - return result - - map_product = map_sum def map_quotient(self, expr, other, unis): if not isinstance(other, type(expr)): @@ -189,7 +162,6 @@ class UnifierBase(RecursiveMapper): return unis - map_product = map_sum map_tuple = map_list @@ -206,6 +178,36 @@ class UnidirectionalUnifier(UnifierBase): subexpression of the second. """ + def map_sum(self, expr, other, unis): + if (not isinstance(other, type(expr)) + or len(expr.children) != len(other.children)): + return [] + + result = [] + + from pytools import generate_permutations + had_structural_match = False + for perm in generate_permutations(range(len(expr.children))): + it_assignments = unis + + for my_child, other_child in zip( + expr.children, + (other.children[i] for i in perm)): + it_assignments = self.rec(my_child, other_child, it_assignments) + if not it_assignments: + break + + if it_assignments: + had_structural_match = True + result.extend(it_assignments) + + if not had_structural_match: + return self.treat_mismatch(expr, other, unis) + + return result + + map_product = map_sum + def treat_mismatch(self, expr, other, unis): return [] @@ -217,3 +219,5 @@ class BidirectionalUnifier(UnifierBase): """ treat_mismatch = UnifierBase.map_variable + map_sum = UnifierBase.map_variable + map_product = UnifierBase.map_variable