diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index e13e56382b7be5c64437354c60a73d8f13fa97a3..b15900933b37d0461d7b4b39274b0c755cdfdacd 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -101,41 +101,41 @@ class UnifierBase(RecursiveMapper): return UnificationRecord([(lhs, rhs)]) - def map_constant(self, expr, other, unis): + def map_constant(self, expr, other, urecs): if expr == other: - return unis + return urecs else: return [] - def map_variable(self, expr, other, unis): + def map_variable(self, expr, other, urecs): new_uni_record = self.unification_record_from_equation( expr, other) if new_uni_record is None: if (isinstance(other, Variable) and other.name == expr.name and expr.name not in self.lhs_mapping_candidates): - return unis + return urecs else: return [] else: - return unify_many(unis, new_uni_record) + return unify_many(urecs, new_uni_record) - def map_subscript(self, expr, other, unis): + def map_subscript(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return self.rec(expr.aggregate, other.aggregate, - self.rec(expr.index, other.index, unis)) + self.rec(expr.index, other.index, urecs)) - def map_lookup(self, expr, other, unis): + def map_lookup(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) if expr.name != other.name: return [] - return self.rec(expr.aggregate, other.aggregate, unis) + return self.rec(expr.aggregate, other.aggregate, urecs) - def map_sum(self, expr, other, unis): + def map_sum(self, expr, other, urecs): if (not isinstance(other, type(expr)) or len(expr.children) != len(other.children)): return [] @@ -145,7 +145,7 @@ class UnifierBase(RecursiveMapper): from pytools import generate_permutations had_structural_match = False for perm in generate_permutations(range(len(expr.children))): - it_assignments = unis + it_assignments = urecs for my_child, other_child in zip( expr.children, @@ -159,52 +159,52 @@ class UnifierBase(RecursiveMapper): result.extend(it_assignments) if not had_structural_match: - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return result map_product = map_sum - def map_negation(self, expr, other, unis): + def map_negation(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) - return self.rec(expr.child, other.child, unis) + return self.treat_mismatch(expr, other, urecs) + return self.rec(expr.child, other.child, urecs) - def map_quotient(self, expr, other, unis): + def map_quotient(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return self.rec(expr.numerator, other.numerator, - self.rec(expr.denominator, other.denominator, unis)) + self.rec(expr.denominator, other.denominator, urecs)) map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr, other, unis): + def map_power(self, expr, other, urecs): if not isinstance(other, type(expr)): - return self.treat_mismatch(expr, other, unis) + return self.treat_mismatch(expr, other, urecs) return self.rec(expr.base, other.base, - self.rec(expr.exponent, other.exponent, unis)) + self.rec(expr.exponent, other.exponent, urecs)) - def map_list(self, expr, other, unis): + def map_list(self, expr, other, urecs): if (not isinstance(other, type(expr)) or len(expr) != len(other)): return [] for my_child, other_child in zip(expr, other): - unis = self.rec(my_child, other_child, unis) - if not unis: + urecs = self.rec(my_child, other_child, urecs) + if not urecs: break - return unis + return urecs map_tuple = map_list - def __call__(self, expr, other, unis=None): - if unis is None: - unis = [UnificationRecord([])] - return self.rec(expr, other, unis) + def __call__(self, expr, other, urecs=None): + if urecs is None: + urecs = [UnificationRecord([])] + return self.rec(expr, other, urecs) @@ -214,14 +214,69 @@ class UnidirectionalUnifier(UnifierBase): subexpression of the second. """ - def treat_mismatch(self, expr, other, unis): + def treat_mismatch(self, expr, other, urecs): return [] + def map_commut_assoc(self, expr, other, urecs, factory): + if not isinstance(other, type(expr)): + return + + plain_cand_variables = [] + non_var_children = [] + for child in expr.children: + if (isinstance(child, Variable) + and child.name in self.lhs_mapping_candidates): + plain_cand_variables.append(child) + else: + non_var_children.append(child) + # list (with indices matching non_var_children) of + # list of tuples (other_index, unifiers) + unification_candidates = [] -class BidirectionalUnifier(UnifierBase): - """Only assigns variables encountered in the first expression to - subexpression of the second. - """ + for i, my_child in enumerate(non_var_children): + i_matches = [] + for j, other_child in enumerate(other.children): + result = self.rec(my_child, other_child, urecs) + if result: + i_matches.append((j, result)) + + unification_candidates.append(i_matches) + + def match_children(urec, next_my_idx, other_leftovers): + if next_my_idx >= len(non_var_children): + if not plain_cand_variables and other_leftovers: + return + + eqns = [] + for pv in plain_cand_variables: + eqns.append((pv, factory( + other.children[i] for i in other_leftovers))) + other_leftovers = [] + + yield urec.unify(UnificationRecord(eqns)) + return + + for other_idx, pair_urecs in unification_candidates[next_my_idx]: + if other_idx not in other_leftovers: + continue + + new_urecs = unify_many(pair_urecs, urec) + new_rhs_leftovers = other_leftovers - set([other_idx]) + + for cand_urec in new_urecs: + for result_urec in match_children(cand_urec, next_my_idx+1, + new_rhs_leftovers): + yield result_urec + + for urec in match_children( + UnificationRecord([]), 0, set(range(len(other.children)))): + yield urec + + def map_sum(self, expr, other, unis): + from pymbolic.primitives import flattened_sum + return list(self.map_commut_assoc(expr, other, unis, flattened_sum)) - treat_mismatch = UnifierBase.map_variable + def map_product(self, expr, other, unis): + from pymbolic.primitives import flattened_product + return list(self.map_commut_assoc(expr, other, unis, flattened_product))