diff --git a/pytools/__init__.py b/pytools/__init__.py index 17550f668cba40a26aafeb908925d22b248fb44e..4328b9e6afc8c6f20932c7b2e56a3a2be2ec904c 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -369,24 +369,48 @@ def single_valued(iterable, equality_pred=operator.eq): # {{{ memoization / attribute storage -@my_decorator -def memoize(func, *args): +def _memoize(func, key, args, kwargs): # by Michele Simionato # http://www.phyast.pitt.edu/~micheles/python/ try: - return func._memoize_dic[args] + return func._memoize_dic[key] except AttributeError: # _memoize_dic doesn't exist yet. - result = func(*args) - func._memoize_dic = {args: result} + result = func(*args, **kwargs) + func._memoize_dic = {key: result} return result except KeyError: - result = func(*args) - func._memoize_dic[args] = result + result = func(*args, **kwargs) + func._memoize_dic[key] = result return result +try: + dict.iteritems + def _iteritem(d): + return d.iteritems() +except AttributeError: + def _iteritem(d): + return d.items() + +def memoize(*args, **kwargs): + key_func = kwargs.pop( + 'key', lambda *a, **kw: (a, frozenset(_iteritem(kw)))) + if kwargs: + raise TypeError( + "memorize recived unexpected keyword arguments: %s" + % ", ".join(kwargs.keys())) + @my_decorator + def _deco(func, *args, **kwargs): + return _memoize(func, key_func(*args, **kwargs), args, kwargs) + if not args: + return _deco + if callable(args[0]) and len(args) == 1: + return _deco(args[0]) + raise TypeError( + "memorize recived unexpected position arguments: %s" % args) + FunctionValueCache = memoize diff --git a/pytools/decorator.py b/pytools/decorator.py index cf8ead36a4208e48878530540fae075f6369e635..502da0f467a24866d655ad6cc9c0c52ac4819de5 100644 --- a/pytools/decorator.py +++ b/pytools/decorator.py @@ -10,7 +10,7 @@ __all__ = ["decorator", "update_wrapper", "getinfo"] import inspect - + def getinfo(func): """ Returns an info dictionary containing: @@ -21,7 +21,7 @@ def getinfo(func): - doc (the docstring : str) - module (the module name : str) - dict (the function __dict__ : str) - + >>> def f(self, x=1, y=2, *args, **kw): pass >>> info = getinfo(f) @@ -30,7 +30,7 @@ def getinfo(func): 'f' >>> info["argnames"] ['self', 'x', 'y', 'args', 'kw'] - + >>> info["defaults"] (1, 2) @@ -68,7 +68,7 @@ def update_wrapper(wrapper, wrapped, create=False): if create: # create a brand new wrapper with the right signature src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict # import sys; print >> sys.stderr, src # for debugging purposes - wrapper = eval(src, dict(_wrapper_=wrapper)) + wrapper = eval(src, dict(_wrapper_=wrapper)) try: wrapper.__name__ = infodict['name'] except: # Python version < 2.4 @@ -102,7 +102,7 @@ def decorator(caller, func=None): def caller(func, *args, **kw): # do something return func(*args, **kw) - + Here is an example of usage: >>> @decorator @@ -112,7 +112,7 @@ def decorator(caller, func=None): >>> chatty.__name__ 'chatty' - + >>> @chatty ... def f(): pass ... @@ -137,13 +137,13 @@ if __name__ == "__main__": import doctest; doctest.testmod() ####################### LEGALESE ################################## - + ## Redistributions of source code must retain the above copyright ## notice, this list of conditions and the following disclaimer. ## Redistributions in bytecode form must reproduce the above copyright ## notice, this list of conditions and the following disclaimer in ## the documentation and/or other materials provided with the -## distribution. +## distribution. ## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT diff --git a/test/test_pytools.py b/test/test_pytools.py index f61e0e44bf3ba58fad945992e25afa7acc4315bb..c6f8a72395abf49ec6666fc94d0651e378596f20 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -89,3 +89,41 @@ def test_p_convergence_verifier(): pconv_verifier.add_data_point(order, 2) with pytest.raises(AssertionError): pconv_verifier() + + +def test_memoize(): + from pytools import memoize + count = [0] + + @memoize + def f(i, j=1): + count[0] += 1 + return i + j + + assert f(1) == 2 + assert f(1, 2) == 3 + assert f(2, j=3) == 5 + assert count[0] == 3 + assert f(1) == 2 + assert f(1, 2) == 3 + assert f(2, j=3) == 5 + assert count[0] == 3 + + +def test_memoize_keyfunc(): + from pytools import memoize + count = [0] + + @memoize(key=lambda i, j=(1,): (i, len(j))) + def f(i, j=(1,)): + count[0] += 1 + return i + len(j) + + assert f(1) == 2 + assert f(1, [2]) == 2 + assert f(2, j=[2, 3]) == 4 + assert count[0] == 2 + assert f(1) == 2 + assert f(1, (2,)) == 2 + assert f(2, j=(2, 3)) == 4 + assert count[0] == 2