Skip to content
Snippets Groups Projects
Commit e6ebdbcd authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge pull request #8 from mattwala/unifier-performance-fix

Fix avoidable exponential blowup in unifier.
parents a59c966c cb697546
No related branches found
No related tags found
No related merge requests found
Pipeline #567 passed with stage
...@@ -326,48 +326,86 @@ class UnidirectionalUnifier(UnifierBase): ...@@ -326,48 +326,86 @@ class UnidirectionalUnifier(UnifierBase):
if not isinstance(other, type(expr)): if not isinstance(other, type(expr)):
return return
# Partition expr into terms that are plain (free) variables and those
# that are not.
plain_var_candidates = []
non_var_children = []
for child in expr.children:
if (isinstance(child, Variable)
and child.name in self.lhs_mapping_candidates):
plain_var_candidates.append(child)
else:
non_var_children.append(child)
# list (with indices matching non_var_children) of
# list of tuples (other_index, unifiers)
unification_candidates = [] unification_candidates = []
def subsets_modulo_ac(n): # Unify non-free-variable children of expr with children of the other
from itertools import combinations # expr.
base = tuple(range(n)) for i, my_child in enumerate(non_var_children):
for r in range(1, n+1): i_matches = []
for p in combinations(base, r): for j, other_child in enumerate(other.children):
yield p result = self.rec(my_child, other_child, urecs)
# Unify children of expr with subtrees of the other expr.
for i, my_child in enumerate(expr.children):
i_matches = {}
for subset in subsets_modulo_ac(len(other.children)):
subtree = factory(other.children[idx] for idx in subset)
result = self.rec(my_child, subtree, urecs)
if result: if result:
i_matches[frozenset(subset)] = result i_matches.append((j, result))
unification_candidates.append(i_matches) unification_candidates.append(i_matches)
# Combine the unification candidates of children in all possible ways. # Combine the unification candidates of children in all possible ways.
def match_children(urec, next_cand_idx, other_leftovers): def match_children(urec, next_cand_idx, other_leftovers):
if next_cand_idx >= len(expr.children): if next_cand_idx >= len(non_var_children):
if len(other_leftovers) == 0: for match in match_plain_var_candidates(urec, other_leftovers):
# Only return records that are fully matched. yield match
yield urec
return return
from six import iteritems for other_idx, pair_urecs in unification_candidates[next_cand_idx]:
for other_idxs, pair_urecs in iteritems( if other_idx not in other_leftovers:
unification_candidates[next_cand_idx]):
if not other_idxs <= other_leftovers:
# Don't re-match any elements. # Don't re-match any elements.
continue continue
new_urecs = unify_many(pair_urecs, urec) new_urecs = unify_many(pair_urecs, urec)
new_rhs_leftovers = other_leftovers - other_idxs new_rhs_leftovers = other_leftovers - set([other_idx])
for cand_urec in new_urecs: for cand_urec in new_urecs:
for result_urec in match_children( for result_urec in match_children(
cand_urec, next_cand_idx + 1, new_rhs_leftovers): cand_urec, next_cand_idx + 1, new_rhs_leftovers):
yield result_urec yield result_urec
def match_plain_var_candidates(urec, other_leftovers):
if len(plain_var_candidates) == len(other_leftovers) == 0:
yield urec
return
# At this point, the values in plain_var_candidates have not
# been matched in the lhs, and the values in other_leftovers
# have not been matched in the rhs. Try all possible
# combinations of matches (this part may become a performance
# bottleneck and if needed could be optimized further).
def subsets(s, max_size):
from itertools import combinations
for size in range(1, max_size + 1):
for subset in combinations(s, size):
yield subset
def partitions(s, k):
if k == 1:
yield [s]
return
for subset in map(set, subsets(s, len(s) - k + 1)):
for partition in partitions(s - subset, k - 1):
yield [subset] + partition
for partition in partitions(
other_leftovers, len(plain_var_candidates)):
result = urec
for subset, var in zip(partition, plain_var_candidates):
eqn = (var, factory(other.children[i] for i in subset))
result = result.unify(UnificationRecord([eqn]))
if not result:
break
else:
yield result
for urec in match_children( for urec in match_children(
UnificationRecord([]), 0, set(range(len(other.children)))): UnificationRecord([]), 0, set(range(len(other.children)))):
yield urec yield urec
......
...@@ -467,6 +467,11 @@ def test_unifier(): ...@@ -467,6 +467,11 @@ def test_unifier():
assert match_found(recs, set([(b, e), (a, d+f)])) assert match_found(recs, set([(b, e), (a, d+f)]))
assert match_found(recs, set([(b, f), (a, d+e)])) assert match_found(recs, set([(b, f), (a, d+e)]))
vals = [var("v" + str(i)) for i in range(100)]
recs = UnidirectionalUnifier("a")(sum(vals[1:]) + a, sum(vals))
assert len(recs) == 1
assert match_found(recs, set([(a, var("v0"))]))
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment