diff --git a/sumpy/cse.py b/sumpy/cse.py
index 350a2c8d97d6ef4729657c85248d412e4128f6a1..ad44fc41c8516dfc90587ce61cab6a3e5070e6ea 100644
--- a/sumpy/cse.py
+++ b/sumpy/cse.py
@@ -183,11 +183,32 @@ class FuncArgTracker(object):
         from collections import defaultdict
         count_map = defaultdict(lambda: 0)
 
-        for arg in argset:
-            for func_i in self.arg_to_funcset[arg]:
+        # Sorted by size to make best use of the performance hack below.
+        funcsets = sorted((self.arg_to_funcset[arg] for arg in argset), key=len)
+
+        for funcset in funcsets[:-threshold+1]:
+            for func_i in funcset:
                 if func_i >= min_func_i:
                     count_map[func_i] += 1
 
+        for i, funcset in enumerate(funcsets[-threshold+1:]):
+            # When looking at the tail end of the funcsets list, items below
+            # this threshold in the count_map don't have to be considered
+            # because they can't possibly be in the output.
+            count_map_threshold = i + 1
+
+            # We pick the smaller of the two containers to iterate over to
+            # reduce the number of items we have to look at.
+            (smaller_funcs_container,
+             larger_funcs_container) = sorted([funcset, count_map], key=len)
+
+            for func_i in smaller_funcs_container:
+                if count_map[func_i] < count_map_threshold:
+                    continue
+
+                if func_i in larger_funcs_container:
+                    count_map[func_i] += 1
+
         return dict(
             (k, v) for k, v in count_map.items()
             if v >= threshold)
@@ -258,14 +279,14 @@ def match_common_args(func_class, funcs, opt_subs):
     from sumpy.tools import OrderedSet
 
     for i in range(len(funcs)):
-        common_arg_candidates = arg_tracker.get_common_arg_candidates(
+        common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
                 arg_tracker.func_to_argset[i], i + 1, threshold=2)
 
         # Sort the candidates in order of match size.
         # This makes us try combining smaller matches first.
         common_arg_candidates = OrderedSet(sorted(
-                common_arg_candidates.keys(),
-                key=lambda k: (common_arg_candidates[k], k)))
+                common_arg_candidates_counts.keys(),
+                key=lambda k: (common_arg_candidates_counts[k], k)))
 
         while common_arg_candidates:
             j = common_arg_candidates.pop(last=False)