From 1bce26ffaf14cf67cdd3e9a1e3b8949ca279fcb8 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 19 Jan 2017 10:15:35 -0600 Subject: [PATCH] CSE: Fix recursive matching. This makes match_common_args() slightly better at detecting recursive matches. Closes #9 --- sumpy/cse.py | 87 ++++++++++++++++++++++++++++++++---------------- sumpy/tools.py | 70 ++++++++++++++++++++++++++++++++++++++ test/test_cse.py | 17 +++++++++- 3 files changed, 144 insertions(+), 30 deletions(-) diff --git a/sumpy/cse.py b/sumpy/cse.py index ca82b971..b689786d 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -174,10 +174,11 @@ class FuncArgTracker(object): for arg in self.func_to_argset[func_i]: self.arg_to_funcset[arg].remove(func_i) - def gen_common_arg_candidates(self, argset, min_func_i, threshold=2): + def get_common_arg_candidates(self, argset, min_func_i, threshold=2): """ - Generate the list of functions which have at least `threshold` arguments in - common from `argset`. + Return a dict whose keys are function numbers. The entries of the dict are + the number of arguments said function has in common with `argset`. Entries + have at least `threshold` items in common. """ from collections import defaultdict count_map = defaultdict(lambda: 0) @@ -187,28 +188,25 @@ class FuncArgTracker(object): if func_i >= min_func_i: count_map[func_i] += 1 - count_map_keys_in_order = sorted( - key for key, val in count_map.items() - if val >= threshold) + return dict( + (k, v) for k, v in count_map.items() + if v >= threshold) - for item in count_map_keys_in_order: - yield item - - def gen_subset_candidates(self, argset, min_func_i): + def get_subset_candidates(self, argset, restrict_to_funcset=frozenset()): """ - Generate the list of functions whose set of arguments contains `argset`. + Return a set of functions each of which whose argument list contains + `argset.` """ iarg = iter(argset) indices = set( fi for fi in self.arg_to_funcset[next(iarg)] - if fi >= min_func_i) + if fi in restrict_to_funcset) for arg in iarg: indices &= self.arg_to_funcset[arg] - for item in indices: - yield item + return indices def update_func_argset(self, func_i, new_argset): """ @@ -246,13 +244,28 @@ def match_common_args(func_class, funcs, opt_subs): :arg funcs: A list of function calls :arg opt_subs: A dictionary of substitutions which this function may update """ + + # Sort to ensure that whole-function subexpressions come before the items + # that use them. + funcs = sorted(funcs, key=lambda f: len(f.args)) arg_tracker = FuncArgTracker(funcs) changed = set() + from sumpy.tools import OrderedSet + for i in range(len(funcs)): - for j in arg_tracker.gen_common_arg_candidates( - arg_tracker.func_to_argset[i], i + 1, threshold=2): + common_arg_candidates = arg_tracker.get_common_arg_candidates( + arg_tracker.func_to_argset[i], i + 1, threshold=2) + + # Sort the candidates in order of match size. + # This makes us try combining smaller matches first. + common_arg_candidates = OrderedSet(sorted( + common_arg_candidates.keys(), + key=lambda k: (common_arg_candidates[k], k))) + + while common_arg_candidates: + j = common_arg_candidates.pop(last=False) com_args = arg_tracker.func_to_argset[i].intersection( arg_tracker.func_to_argset[j]) @@ -262,21 +275,36 @@ def match_common_args(func_class, funcs, opt_subs): # combined in a previous iteration. continue - com_func = func_class(*arg_tracker.get_args_in_value_order(com_args)) - com_func_number = arg_tracker.get_or_add_value_number(com_func) - # For all sets, replace the common symbols by the function # over them, to allow recursive matches. diff_i = arg_tracker.func_to_argset[i].difference(com_args) - arg_tracker.update_func_argset(i, diff_i | set([com_func_number])) - changed.add(i) + if diff_i: + # com_func needs to be unevaluated to allow for recursive matches. + com_func = func_class( + *arg_tracker.get_args_in_value_order(com_args), + evaluate=False) + com_func_number = arg_tracker.get_or_add_value_number(com_func) + arg_tracker.update_func_argset(i, diff_i | set([com_func_number])) + changed.add(i) + else: + # Treat the whole expression as a CSE. + # + # The reason this needs to be done is somewhat subtle. Within + # tree_cse(), to_eliminate only contains expressions that are + # seen more than once. The problem is unevaluated expressions + # do not compare equal to the evaluated equivalent. So + # tree_cse() won't mark funcs[i] as a CSE if we use an + # unevaluated version. + com_func = funcs[i] + com_func_number = arg_tracker.get_or_add_value_number(funcs[i]) diff_j = arg_tracker.func_to_argset[j].difference(com_args) arg_tracker.update_func_argset(j, diff_j | set([com_func_number])) changed.add(j) - for k in arg_tracker.gen_subset_candidates(com_args, j + 1): + for k in arg_tracker.get_subset_candidates( + com_args, common_arg_candidates): diff_k = arg_tracker.func_to_argset[k].difference(com_args) arg_tracker.update_func_argset(k, diff_k | set([com_func_number])) changed.add(k) @@ -299,15 +327,15 @@ def opt_cse(exprs): """ opt_subs = dict() - adds = [] - muls = [] + from sumpy.tools import OrderedSet + adds = OrderedSet() + muls = OrderedSet() seen_subexp = set() # {{{ look for optimization opportunities, clean up minus signs def find_opts(expr): - if not isinstance(expr, Basic): return @@ -335,10 +363,10 @@ def opt_cse(exprs): expr = neg_expr if isinstance(expr, Mul): - muls.append(expr) + muls.add(expr) elif isinstance(expr, Add): - adds.append(expr) + adds.add(expr) elif isinstance(expr, Pow): base, exp = expr.args @@ -351,8 +379,9 @@ def opt_cse(exprs): if isinstance(e, Basic): find_opts(e) - match_common_args(Add, adds, opt_subs) - match_common_args(Mul, muls, opt_subs) + match_common_args(Add, list(adds), opt_subs) + match_common_args(Mul, list(muls), opt_subs) + return opt_subs # }}} diff --git a/sumpy/tools.py b/sumpy/tools.py index 5e0d4341..60687510 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -279,6 +279,76 @@ class KernelComputation(object): # }}} +# {{{ OrderedSet + +# Source: http://code.activestate.com/recipes/576694-orderedset/ +# Author: Raymond Hettinger +# License: MIT + +import collections + + +class OrderedSet(collections.MutableSet): + + def __init__(self, iterable=None): + self.end = end = [] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] + if iterable is not None: + self |= iterable + + def __len__(self): + return len(self.map) + + def __contains__(self, key): + return key in self.map + + def add(self, key): + if key not in self.map: + end = self.end + curr = end[1] + curr[2] = end[1] = self.map[key] = [key, curr, end] + + def discard(self, key): + if key in self.map: + key, prev, next = self.map.pop(key) + prev[2] = next + next[1] = prev + + def __iter__(self): + end = self.end + curr = end[2] + while curr is not end: + yield curr[0] + curr = curr[2] + + def __reversed__(self): + end = self.end + curr = end[1] + while curr is not end: + yield curr[0] + curr = curr[1] + + def pop(self, last=True): + if not self: + raise KeyError('set is empty') + key = self.end[1][0] if last else self.end[2][0] + self.discard(key) + return key + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, list(self)) + + def __eq__(self, other): + if isinstance(other, OrderedSet): + return len(self) == len(other) and list(self) == list(other) + return set(self) == set(other) + +# }}} + + class KernelCacheWrapper(object): @memoize_method def get_cached_optimized_kernel(self, **kwargs): diff --git a/test/test_cse.py b/test/test_cse.py index b5697ac4..d2904176 100644 --- a/test/test_cse.py +++ b/test/test_cse.py @@ -253,7 +253,7 @@ def test_issue_4499(): sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b, -2*a)) # noqa c = cse(t) - assert len(c[0]) == 13 + assert len(c[0]) == 11 def test_issue_6169(): @@ -345,6 +345,21 @@ def test_issue_7840(): assert len(substitutions) < 1 +def test_recursive_matching(): + assert cse([x+y, 2+x+y, x+y+z, 3+x+y+z]) == \ + ([(x0, x + y), (x1, x0 + z)], [x0, x0 + 2, x1, x1 + 3]) + assert cse(reversed([x+y, 2+x+y, x+y+z, 3+x+y+z])) == \ + ([(x0, x + y), (x1, x0 + z)], [x1 + 3, x1, x0 + 2, x0]) + # sympy 1.0 gives ([(x0, x*y*z)], [5*x0, w*(x*y), 3*x0]) + assert cse([x*y*z*5, x*y*w, x*y*z*3]) == \ + ([(x0, x*y), (x1, x0*z)], [5*x1, w*x0, 3*x1]) + # sympy 1.0 gives ([(x4, x*y*z)], [5*x4, w*x3*x4, 3*x*x0*x1*x2*y]) + assert cse([x*y*z*5, x*y*z*w*x3, x*y*3*x0*x1*x2]) == \ + ([(x4, x*y), (x5, x4*z)], [5*x5, w*x3*x5, 3*x0*x1*x2*x4]) + assert cse([2*x*x, x*x*y, x*x*y*w, x*x*y*w*x0, x*x*y*w*x2]) == \ + ([(x1, x**2), (x3, x1*y), (x4, w*x3)], [2*x1, x3, x4, x0*x4, x2*x4]) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab