From efcc7cace40cd754423bc705ef336509be27d876 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Mon, 28 Mar 2016 13:12:43 -0500 Subject: [PATCH] Fix unidirectional unification. The unidirectional unifier now works on examples such as a+b*c => d+e*f. This required a couple of changes to the algorithm, the main one being that we allow variables to match against subtrees of the expression. --- pymbolic/mapper/unifier.py | 62 +++++++++++++++++--------------------- test/test_pymbolic.py | 26 ++++++++++++++++ 2 files changed, 54 insertions(+), 34 deletions(-) diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index 6187e74..346d2ac 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -326,52 +326,46 @@ class UnidirectionalUnifier(UnifierBase): if not isinstance(other, type(expr)): return - plain_cand_variables = [] - non_var_children = [] - for child in expr.children: - if (isinstance(child, Variable) - and child.name in self.lhs_mapping_candidates): - plain_cand_variables.append(child) - else: - non_var_children.append(child) - - # list (with indices matching non_var_children) of - # list of tuples (other_index, unifiers) unification_candidates = [] - for i, my_child in enumerate(non_var_children): - i_matches = [] - for j, other_child in enumerate(other.children): - result = self.rec(my_child, other_child, urecs) + def subsets_modulo_ac(n): + from itertools import combinations + base = tuple(range(n)) + for r in range(1, n+1): + for p in combinations(base, r): + yield p + + # Unify children of expr with subtrees of the other expr. + for i, my_child in enumerate(expr.children): + i_matches = {} + for subset in subsets_modulo_ac(len(other.children)): + subtree = factory(other.children[idx] for idx in subset) + result = self.rec(my_child, subtree, urecs) if result: - i_matches.append((j, result)) - + i_matches[frozenset(subset)] = result unification_candidates.append(i_matches) - def match_children(urec, next_my_idx, other_leftovers): - if next_my_idx >= len(non_var_children): - if not plain_cand_variables and other_leftovers: - return - - eqns = [] - for pv in plain_cand_variables: - eqns.append((pv, factory( - other.children[i] for i in other_leftovers))) - other_leftovers = [] - - yield urec.unify(UnificationRecord(eqns)) + # Combine the unification candidates of children in all possible ways. + def match_children(urec, next_cand_idx, other_leftovers): + if next_cand_idx >= len(expr.children): + if len(other_leftovers) == 0: + # Only return records that are fully matched. + yield urec return - for other_idx, pair_urecs in unification_candidates[next_my_idx]: - if other_idx not in other_leftovers: + from six import iteritems + for other_idxs, pair_urecs in iteritems( + unification_candidates[next_cand_idx]): + if not other_idxs <= other_leftovers: + # Don't re-match any elements. continue new_urecs = unify_many(pair_urecs, urec) - new_rhs_leftovers = other_leftovers - set([other_idx]) + new_rhs_leftovers = other_leftovers - other_idxs for cand_urec in new_urecs: - for result_urec in match_children(cand_urec, next_my_idx+1, - new_rhs_leftovers): + for result_urec in match_children( + cand_urec, next_cand_idx + 1, new_rhs_leftovers): yield result_urec for urec in match_children( diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index a81871a..9207449 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -442,6 +442,32 @@ def test_compile(): assert code(3, 3) == 27 +def test_unifier(): + from pymbolic import var + from pymbolic.mapper.unifier import UnidirectionalUnifier + a, b, c, d, e, f = [var(s) for s in "abcdef"] + + def match_found(records, eqns): + for record in records: + if eqns <= set(record.equations): + return True + return False + + recs = UnidirectionalUnifier("abc")(a+b*c, d+e*f) + assert len(recs) == 2 + assert match_found(recs, set([(a, d), (b, e), (c, f)])) + assert match_found(recs, set([(a, d), (b, f), (c, e)])) + + recs = UnidirectionalUnifier("abc")(a+b, d+e+f) + assert len(recs) == 6 + assert match_found(recs, set([(a, d), (b, e+f)])) + assert match_found(recs, set([(a, e), (b, d+f)])) + assert match_found(recs, set([(a, f), (b, d+e)])) + assert match_found(recs, set([(b, d), (a, e+f)])) + assert match_found(recs, set([(b, e), (a, d+f)])) + assert match_found(recs, set([(b, f), (a, d+e)])) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab