diff --git a/pytools/__init__.py b/pytools/__init__.py index 6066a1e0656716eff74fe6fa2bc6f50846cfe46d..af84b1bc55a0564ee31fe15795352516b0602335 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -24,6 +24,7 @@ THE SOFTWARE. import operator +import sys from pytools.decorator import decorator @@ -351,7 +352,7 @@ def single_valued(iterable, equality_pred=operator.eq): # }}} -# {{{ memoization +# {{{ memoization / attribute storage @my_decorator def memoize(func, *args): @@ -370,22 +371,54 @@ def memoize(func, *args): result = func(*args) func._memoize_dic[args] = result return result + FunctionValueCache = memoize -@my_decorator -def memoize_method(method, instance, *args): - cache_dict_name = intern("_memoize_dic_"+method.__name__) - try: - return getattr(instance, cache_dict_name)[args] - except AttributeError: - result = method(instance, *args) - setattr(instance, cache_dict_name, {args: result}) - return result - except KeyError: - result = method(instance, *args) - getattr(instance, cache_dict_name)[args] = result - return result +if sys.version_info >= (2, 5): + # For Python 2.5 and newer, support cache deletion by a + # 'method_name.clear_cache(self)' call. + + def memoize_method(method): + cache_dict_name = intern("_memoize_dic_"+method.__name__) + + def wrapper(self, *args): + try: + return getattr(self, cache_dict_name)[args] + except AttributeError: + result = method(self, *args) + setattr(self, cache_dict_name, {args: result}) + return result + except KeyError: + result = method(self, *args) + getattr(self, cache_dict_name)[args] = result + return result + + def clear_cache(self): + delattr(self, cache_dict_name) + + from functools import update_wrapper + new_wrapper = update_wrapper(wrapper, method) + new_wrapper.clear_cache = clear_cache + + return new_wrapper + +else: + # For sad old Python 2.4, cache deletion is not supported. + + @my_decorator + def memoize_method(method, instance, *args): + cache_dict_name = intern("_memoize_dic_"+method.__name__) + try: + return getattr(instance, cache_dict_name)[args] + except AttributeError: + result = method(instance, *args) + setattr(instance, cache_dict_name, {args: result}) + return result + except KeyError: + result = method(instance, *args) + getattr(instance, cache_dict_name)[args] = result + return result def memoize_method_nested(inner): @@ -424,8 +457,6 @@ def memoize_method_nested(inner): return new_inner -FunctionValueCache = memoize - # }}} diff --git a/test/test_pytools.py b/test/test_pytools.py index 7507971ba4845df01e5051a114160bb996eee55c..f8d0109c11fb44801a4401642687222a400be626 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -1,5 +1,8 @@ from __future__ import division +import pytest +import sys # noqa + def test_memoize_method_nested(): from pytools import memoize_method_nested @@ -21,3 +24,24 @@ def test_memoize_method_nested(): sc = SomeClass() sc.f() assert sc.run_count == 1 + + +@pytest.mark.skipif("sys.version_info < (2, 5)") +def test_memoize_method_clear(): + from pytools import memoize_method + + class SomeClass: + def __init__(self): + self.run_count = 0 + + @memoize_method + def f(self): + self.run_count += 1 + return 17 + + sc = SomeClass() + sc.f() + sc.f() + assert sc.run_count == 1 + + sc.f.clear_cache(sc)