From 08caf3a1d88e2edb639fb504182219511624247a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 28 Oct 2013 19:46:32 -0500 Subject: [PATCH] Add memoize_method_with_uncached --- pytools/__init__.py | 62 +++++++++++++++++++++++++++++++++++++++++++- test/test_pytools.py | 56 +++++++++++++++++++++++++++------------ 2 files changed, 100 insertions(+), 18 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 5c2828e..846e042 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -378,8 +378,13 @@ FunctionValueCache = memoize class _HasKwargs(object): pass + def memoize_method(method): - """Supports cache deletion via ``method_name.clear_cache(self)``.""" + """Supports cache deletion via ``method_name.clear_cache(self)``. + + .. note:: + *clear_cache* support requires Python 2.5 or newer. + """ cache_dict_name = intern("_memoize_dic_"+method.__name__) @@ -411,6 +416,61 @@ def memoize_method(method): return new_wrapper +def memoize_method_with_uncached(uncached_args=[], uncached_kwargs=set()): + """Supports cache deletion via ``method_name.clear_cache(self)``. + + :arg uncached_args: a list of argument numbers + (0-based, not counting 'self' argument) + """ + + # delete starting from the end + uncached_args = sorted(uncached_args, reverse=True) + uncached_kwargs = list(uncached_kwargs) + + def parametrized_decorator(method): + cache_dict_name = intern("_memoize_dic_"+method.__name__) + + def wrapper(self, *args, **kwargs): + cache_args = list(args) + cache_kwargs = kwargs.copy() + + for i in uncached_args: + if i < len(cache_args): + cache_args.pop(i) + + cache_args = tuple(cache_args) + + if kwargs: + for name in uncached_kwargs: + cache_kwargs.pop(name, None) + + key = (_HasKwargs, frozenset(cache_kwargs.iteritems())) + cache_args + else: + key = cache_args + + try: + return getattr(self, cache_dict_name)[key] + except AttributeError: + result = method(self, *args, **kwargs) + setattr(self, cache_dict_name, {key: result}) + return result + except KeyError: + result = method(self, *args, **kwargs) + getattr(self, cache_dict_name)[key] = result + return result + + def clear_cache(self): + delattr(self, cache_dict_name) + + if sys.version_info >= (2, 5): + from functools import update_wrapper + new_wrapper = update_wrapper(wrapper, method) + new_wrapper.clear_cache = clear_cache + + return new_wrapper + + return parametrized_decorator + def memoize_method_nested(inner): """Adds a cache to a function nested inside a method. The cache is attached diff --git a/test/test_pytools.py b/test/test_pytools.py index 7f0e3f4..7cecf68 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -4,49 +4,71 @@ import pytest import sys # noqa -def test_memoize_method_nested(): - from pytools import memoize_method_nested +@pytest.mark.skipif("sys.version_info < (2, 5)") +def test_memoize_method_clear(): + from pytools import memoize_method class SomeClass: def __init__(self): self.run_count = 0 + @memoize_method def f(self): - - @memoize_method_nested - def inner(x): - self.run_count += 1 - return 2*x - - inner(5) - inner(5) + self.run_count += 1 + return 17 sc = SomeClass() sc.f() + sc.f() assert sc.run_count == 1 + sc.f.clear_cache(sc) + -@pytest.mark.skipif("sys.version_info < (2, 5)") -def test_memoize_method_clear(): - from pytools import memoize_method +def test_memoize_method_with_uncached(): + from pytools import memoize_method_with_uncached class SomeClass: def __init__(self): self.run_count = 0 - @memoize_method - def f(self): + @memoize_method_with_uncached(uncached_args=[1], uncached_kwargs=["z"]) + def f(self, x, y, z): self.run_count += 1 return 17 sc = SomeClass() - sc.f() - sc.f() + sc.f(17, 18, z=19) + sc.f(17, 19, z=20) assert sc.run_count == 1 + sc.f(18, 19, z=20) + assert sc.run_count == 2 sc.f.clear_cache(sc) +def test_memoize_method_nested(): + from pytools import memoize_method_nested + + class SomeClass: + def __init__(self): + self.run_count = 0 + + def f(self): + + @memoize_method_nested + def inner(x): + self.run_count += 1 + return 2*x + + inner(5) + inner(5) + + sc = SomeClass() + sc.f() + assert sc.run_count == 1 + + def test_p_convergence_verifier(): from pytools.convergence import PConvergenceVerifier -- GitLab