diff --git a/sumpy/cse.py b/sumpy/cse.py index 6a0270e7a31872127dbb0655653a128a5cd5b42c..3318fba023eccbc7673fab31db5d31b0cfeb4035 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -196,16 +196,19 @@ class FuncArgTracker(object): (k, v) for k, v in count_map.items() if v >= threshold) - def get_subset_candidates(self, argset, restrict_to_funcset=frozenset()): + def get_subset_candidates(self, argset, restrict_to_funcset=None): """ Return a set of functions each of which whose argument list contains - `argset.` + `argset`, optionally filtered only to contain functions in + `restrict_to_funcset`. """ iarg = iter(argset) indices = set( - fi for fi in self.arg_to_funcset[next(iarg)] - if fi in restrict_to_funcset) + fi for fi in self.arg_to_funcset[next(iarg)]) + + if restrict_to_funcset is not None: + indices &= restrict_to_funcset for arg in iarg: indices &= self.arg_to_funcset[arg]