From 9ff594e572f539d1e542ed3bdb67e2c64260a4e8 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Thu, 24 Aug 2017 03:50:51 -0500 Subject: [PATCH 1/2] get_common_arg_candidates(): Avoid calling sort(). This code is the hot spot for CSE and so it is better to avoid any needlessly expensive function calls. --- sumpy/cse.py | 57 ++++++++++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/sumpy/cse.py b/sumpy/cse.py index 573c78c9..53d04b67 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -178,44 +178,44 @@ class FuncArgTracker(object): for arg in self.func_to_argset[func_i]: self.arg_to_funcset[arg].remove(func_i) - def get_common_arg_candidates(self, argset, min_func_i, threshold=2): - """ - Return a dict whose keys are function numbers. The entries of the dict are - the number of arguments said function has in common with `argset`. Entries - have at least `threshold` items in common. + def get_common_arg_candidates(self, argset, min_func_i=0): + """Return a dict whose keys are function numbers. The entries of the dict are + the number of arguments said function has in common with + `argset`. Entries have at least 2 items in common. All keys have + value at least `min_func_i`. """ from collections import defaultdict count_map = defaultdict(lambda: 0) - # Sorted by size to make best use of the performance hack below. - funcsets = sorted((self.arg_to_funcset[arg] for arg in argset), key=len) + funcsets = [self.arg_to_funcset[arg] for arg in argset] + # As an optimization below, we handle the largest funcset separately from + # the others. + largest_funcset = max(funcsets, key=len) - for funcset in funcsets[:-threshold+1]: + for funcset in funcsets: + if largest_funcset is funcset: + continue for func_i in funcset: if func_i >= min_func_i: count_map[func_i] += 1 - for i, funcset in enumerate(funcsets[-threshold+1:]): - # When looking at the tail end of the funcsets list, items below - # this threshold in the count_map don't have to be considered - # because they can't possibly be in the output. - count_map_threshold = i + 1 - - # We pick the smaller of the two containers to iterate over to - # reduce the number of items we have to look at. - (smaller_funcs_container, - larger_funcs_container) = sorted([funcset, count_map], key=len) - - for func_i in smaller_funcs_container: - if count_map[func_i] < count_map_threshold: - continue + # We pick the smaller of the two containers (count_map, largest_funcset) + # to iterate over to reduce the number of iterations needed. + (smaller_funcs_container, + larger_funcs_container) = sorted( + [largest_funcset, count_map], + key=len) + + for func_i in smaller_funcs_container: + # Not already in count_map? It can't possibly be in the output, so + # skip it. + if count_map[func_i] < 1: + continue - if func_i in larger_funcs_container: - count_map[func_i] += 1 + if func_i in larger_funcs_container: + count_map[func_i] += 1 - return dict( - (k, v) for k, v in count_map.items() - if v >= threshold) + return dict((k, v) for k, v in count_map.items() if v >= 2) def get_subset_candidates(self, argset, restrict_to_funcset=None): """ @@ -297,7 +297,7 @@ def match_common_args(func_class, funcs, opt_subs): for i in range(len(funcs)): common_arg_candidates_counts = arg_tracker.get_common_arg_candidates( - arg_tracker.func_to_argset[i], i + 1, threshold=2) + arg_tracker.func_to_argset[i], min_func_i=i + 1) # Sort the candidates in order of match size. # This makes us try combining smaller matches first. @@ -375,6 +375,7 @@ def opt_cse(exprs): # {{{ look for optimization opportunities, clean up minus signs def find_opts(expr): + print("EXPR IS", expr) if not isinstance(expr, Basic): return -- GitLab From efde3e42d78e192c305dbf2d08663c611b0b7983 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Mon, 28 Aug 2017 16:35:39 -0500 Subject: [PATCH 2/2] Remove debugging statement. --- sumpy/cse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sumpy/cse.py b/sumpy/cse.py index 53d04b67..823ef546 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -375,7 +375,6 @@ def opt_cse(exprs): # {{{ look for optimization opportunities, clean up minus signs def find_opts(expr): - print("EXPR IS", expr) if not isinstance(expr, Basic): return -- GitLab