diff --git a/sumpy/cse.py b/sumpy/cse.py index 573c78c93961c8fe0e69c9650c4df7d10133e98b..823ef54667ff6c1b8895eab3d59e9b3bd5316f7a 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -178,44 +178,44 @@ class FuncArgTracker(object): for arg in self.func_to_argset[func_i]: self.arg_to_funcset[arg].remove(func_i) - def get_common_arg_candidates(self, argset, min_func_i, threshold=2): - """ - Return a dict whose keys are function numbers. The entries of the dict are - the number of arguments said function has in common with `argset`. Entries - have at least `threshold` items in common. + def get_common_arg_candidates(self, argset, min_func_i=0): + """Return a dict whose keys are function numbers. The entries of the dict are + the number of arguments said function has in common with + `argset`. Entries have at least 2 items in common. All keys have + value at least `min_func_i`. """ from collections import defaultdict count_map = defaultdict(lambda: 0) - # Sorted by size to make best use of the performance hack below. - funcsets = sorted((self.arg_to_funcset[arg] for arg in argset), key=len) + funcsets = [self.arg_to_funcset[arg] for arg in argset] + # As an optimization below, we handle the largest funcset separately from + # the others. + largest_funcset = max(funcsets, key=len) - for funcset in funcsets[:-threshold+1]: + for funcset in funcsets: + if largest_funcset is funcset: + continue for func_i in funcset: if func_i >= min_func_i: count_map[func_i] += 1 - for i, funcset in enumerate(funcsets[-threshold+1:]): - # When looking at the tail end of the funcsets list, items below - # this threshold in the count_map don't have to be considered - # because they can't possibly be in the output. - count_map_threshold = i + 1 - - # We pick the smaller of the two containers to iterate over to - # reduce the number of items we have to look at. - (smaller_funcs_container, - larger_funcs_container) = sorted([funcset, count_map], key=len) - - for func_i in smaller_funcs_container: - if count_map[func_i] < count_map_threshold: - continue + # We pick the smaller of the two containers (count_map, largest_funcset) + # to iterate over to reduce the number of iterations needed. + (smaller_funcs_container, + larger_funcs_container) = sorted( + [largest_funcset, count_map], + key=len) + + for func_i in smaller_funcs_container: + # Not already in count_map? It can't possibly be in the output, so + # skip it. + if count_map[func_i] < 1: + continue - if func_i in larger_funcs_container: - count_map[func_i] += 1 + if func_i in larger_funcs_container: + count_map[func_i] += 1 - return dict( - (k, v) for k, v in count_map.items() - if v >= threshold) + return dict((k, v) for k, v in count_map.items() if v >= 2) def get_subset_candidates(self, argset, restrict_to_funcset=None): """ @@ -297,7 +297,7 @@ def match_common_args(func_class, funcs, opt_subs): for i in range(len(funcs)): common_arg_candidates_counts = arg_tracker.get_common_arg_candidates( - arg_tracker.func_to_argset[i], i + 1, threshold=2) + arg_tracker.func_to_argset[i], min_func_i=i + 1) # Sort the candidates in order of match size. # This makes us try combining smaller matches first.