diff --git a/pytools/__init__.py b/pytools/__init__.py index bb3412a41007b6bca179fa47b8100cdc8f3658f8..8be65a62d6a1f72a2630090fff33e88cac6fb392 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -375,21 +375,29 @@ def memoize(func, *args): FunctionValueCache = memoize +class _HasKwargs(object): + pass + def memoize_method(method): """Supports cache deletion via ``method_name.clear_cache(self)``.""" cache_dict_name = intern("_memoize_dic_"+method.__name__) - def wrapper(self, *args): + def wrapper(self, *args, **kwargs): + if kwargs: + key = (_HasKwargs, frozenset(kwargs.iteritems())) + args + else: + key = args + try: - return getattr(self, cache_dict_name)[args] + return getattr(self, cache_dict_name)[key] except AttributeError: - result = method(self, *args) - setattr(self, cache_dict_name, {args: result}) + result = method(self, *args, **kwargs) + setattr(self, cache_dict_name, {key: result}) return result except KeyError: - result = method(self, *args) - getattr(self, cache_dict_name)[args] = result + result = method(self, *args, **kwargs) + getattr(self, cache_dict_name)[key] = result return result def clear_cache(self):