diff --git a/sumpy/cse.py b/sumpy/cse.py index b689786d1378503d2271b4fb2d65e916fadf9e49..350a2c8d97d6ef4729657c85248d412e4128f6a1 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -192,16 +192,19 @@ class FuncArgTracker(object): (k, v) for k, v in count_map.items() if v >= threshold) - def get_subset_candidates(self, argset, restrict_to_funcset=frozenset()): + def get_subset_candidates(self, argset, restrict_to_funcset=None): """ Return a set of functions each of which whose argument list contains - `argset.` + `argset`, optionally filtered only to contain functions in + `restrict_to_funcset`. """ iarg = iter(argset) indices = set( - fi for fi in self.arg_to_funcset[next(iarg)] - if fi in restrict_to_funcset) + fi for fi in self.arg_to_funcset[next(iarg)]) + + if restrict_to_funcset is not None: + indices &= restrict_to_funcset for arg in iarg: indices &= self.arg_to_funcset[arg]