Skip to content
Snippets Groups Projects
Commit cb697546 authored by Matt Wala's avatar Matt Wala
Browse files

Fix avoidable exponential blowup in unifier.

Change the code to do what the previous code tried to do, which is to
match free variables only after matching the more complex terms. This
avoids iterating over all possible subsets of the RHS.
parent 0a635ffb
No related branches found
No related tags found
No related merge requests found
......@@ -326,48 +326,86 @@ class UnidirectionalUnifier(UnifierBase):
if not isinstance(other, type(expr)):
return
# Partition expr into terms that are plain (free) variables and those
# that are not.
plain_var_candidates = []
non_var_children = []
for child in expr.children:
if (isinstance(child, Variable)
and child.name in self.lhs_mapping_candidates):
plain_var_candidates.append(child)
else:
non_var_children.append(child)
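# unification_candidates[i] corresponds to non_var_children[i] and holds a
# list of (other_index, unifiers) tuples, one for each child of the other
# expr that non_var_children[i] unifies with.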
unification_candidates = []
def subsets_modulo_ac(n):
    """Yield every non-empty subset of ``range(n)`` as an ascending index tuple.

    Subsets are produced by increasing size and, within one size, in the
    order generated by :func:`itertools.combinations`.  Each subset appears
    exactly once; different orderings of the same indices are not
    distinguished.  There are ``2**n - 1`` results, so this grows
    exponentially with *n*.
    """
    from itertools import combinations

    for size in range(1, n + 1):
        for index_subset in combinations(range(n), size):
            yield index_subset
# Unify children of expr with subtrees of the other expr.
for i, my_child in enumerate(expr.children):
i_matches = {}
for subset in subsets_modulo_ac(len(other.children)):
subtree = factory(other.children[idx] for idx in subset)
result = self.rec(my_child, subtree, urecs)
# Unify non-free-variable children of expr with children of the other
# expr.
for i, my_child in enumerate(non_var_children):
i_matches = []
for j, other_child in enumerate(other.children):
result = self.rec(my_child, other_child, urecs)
if result:
i_matches[frozenset(subset)] = result
i_matches.append((j, result))
unification_candidates.append(i_matches)
# Combine the unification candidates of children in all possible ways.
def match_children(urec, next_cand_idx, other_leftovers):
if next_cand_idx >= len(expr.children):
if len(other_leftovers) == 0:
# Only return records that are fully matched.
yield urec
if next_cand_idx >= len(non_var_children):
for match in match_plain_var_candidates(urec, other_leftovers):
yield match
return
from six import iteritems
for other_idxs, pair_urecs in iteritems(
unification_candidates[next_cand_idx]):
if not other_idxs <= other_leftovers:
for other_idx, pair_urecs in unification_candidates[next_cand_idx]:
if other_idx not in other_leftovers:
# Don't re-match any elements.
continue
new_urecs = unify_many(pair_urecs, urec)
new_rhs_leftovers = other_leftovers - other_idxs
new_rhs_leftovers = other_leftovers - set([other_idx])
for cand_urec in new_urecs:
for result_urec in match_children(
cand_urec, next_cand_idx + 1, new_rhs_leftovers):
yield result_urec
def match_plain_var_candidates(urec, other_leftovers):
    """Yield extensions of *urec* that bind the free variables to the leftovers.

    *urec* is the unification record built so far and *other_leftovers* is
    the set of indices into ``other.children`` that have not been matched
    yet.  The leftovers are split into ``len(plain_var_candidates)`` groups
    in every possible way; group number ``j`` is combined with ``factory``
    and unified with ``plain_var_candidates[j]``.  Only records for which
    every such unification succeeds are yielded.

    Reads ``plain_var_candidates``, ``other``, ``factory`` and
    ``UnificationRecord`` from the enclosing scope.
    """
    from itertools import combinations

    if not plain_var_candidates:
        # No free variable is left to absorb anything, so this is a match
        # only if the rhs has been consumed completely.  Returning here
        # also avoids calling partitions() with k == 0, which never reaches
        # its k == 1 base case and would walk every chain of non-empty
        # subsets of the leftovers without yielding a single result.
        if not other_leftovers:
            yield urec
        return

    # At this point, the values in plain_var_candidates have not
    # been matched in the lhs, and the values in other_leftovers
    # have not been matched in the rhs. Try all possible
    # combinations of matches (this part may become a performance
    # bottleneck and if needed could be optimized further).

    def subsets(s, max_size):
        # All non-empty subsets of s with at most max_size elements.
        for size in range(1, max_size + 1):
            for subset in combinations(s, size):
                yield subset

    def partitions(s, k):
        # All ordered splits of the set s into k disjoint groups.  Capping
        # the size of each picked group at len(s) - k + 1 leaves at least
        # one element for each of the remaining k - 1 groups.
        # NOTE(review): for k == 1 and an empty s this yields [set()], so a
        # single free variable gets paired with factory() over no children
        # -- confirm that this is intended.
        if k == 1:
            yield [s]
            return
        for subset in map(set, subsets(s, len(s) - k + 1)):
            for partition in partitions(s - subset, k - 1):
                yield [subset] + partition

    for partition in partitions(
            other_leftovers, len(plain_var_candidates)):
        result = urec
        for subset, var in zip(partition, plain_var_candidates):
            eqn = (var, factory(other.children[i] for i in subset))
            result = result.unify(UnificationRecord([eqn]))
            if not result:
                # This assignment is inconsistent; try the next partition.
                break
        else:
            # Every variable unified successfully.
            yield result
for urec in match_children(
UnificationRecord([]), 0, set(range(len(other.children)))):
yield urec
......
......@@ -467,6 +467,11 @@ def test_unifier():
assert match_found(recs, set([(b, e), (a, d+f)]))
assert match_found(recs, set([(b, f), (a, d+e)]))
vals = [var("v" + str(i)) for i in range(100)]
recs = UnidirectionalUnifier("a")(sum(vals[1:]) + a, sum(vals))
assert len(recs) == 1
assert match_found(recs, set([(a, var("v0"))]))
if __name__ == "__main__":
import sys
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment