diff --git a/sumpy/cse.py b/sumpy/cse.py index 350a2c8d97d6ef4729657c85248d412e4128f6a1..ad44fc41c8516dfc90587ce61cab6a3e5070e6ea 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -183,11 +183,32 @@ class FuncArgTracker(object): from collections import defaultdict count_map = defaultdict(lambda: 0) - for arg in argset: - for func_i in self.arg_to_funcset[arg]: + # Sorted by size to make best use of the performance hack below. + funcsets = sorted((self.arg_to_funcset[arg] for arg in argset), key=len) + + for funcset in funcsets[:-threshold+1]: + for func_i in funcset: if func_i >= min_func_i: count_map[func_i] += 1 + for i, funcset in enumerate(funcsets[-threshold+1:]): + # When looking at the tail end of the funcsets list, items below + # this threshold in the count_map don't have to be considered + # because they can't possibly be in the output. + count_map_threshold = i + 1 + + # We pick the smaller of the two containers to iterate over to + # reduce the number of items we have to look at. + (smaller_funcs_container, + larger_funcs_container) = sorted([funcset, count_map], key=len) + + for func_i in smaller_funcs_container: + if count_map[func_i] < count_map_threshold: + continue + + if func_i in larger_funcs_container: + count_map[func_i] += 1 + return dict( (k, v) for k, v in count_map.items() if v >= threshold) @@ -258,14 +279,14 @@ def match_common_args(func_class, funcs, opt_subs): from sumpy.tools import OrderedSet for i in range(len(funcs)): - common_arg_candidates = arg_tracker.get_common_arg_candidates( + common_arg_candidates_counts = arg_tracker.get_common_arg_candidates( arg_tracker.func_to_argset[i], i + 1, threshold=2) # Sort the candidates in order of match size. # This makes us try combining smaller matches first. common_arg_candidates = OrderedSet(sorted( - common_arg_candidates.keys(), - key=lambda k: (common_arg_candidates[k], k))) + common_arg_candidates_counts.keys(), + key=lambda k: (common_arg_candidates_counts[k], k))) while common_arg_candidates: j = common_arg_candidates.pop(last=False)