From 6ca809273a9846522a1f0bdbcf97b02e7d1072f5 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992@gmail.com>
Date: Sat, 17 May 2014 14:46:05 -0400
Subject: [PATCH] add keyword argument support and key= keyword argument to
 function memoize in order to support arguments that are not directly hashable
 add test for pytools.memoize

---
 pytools/__init__.py  | 38 +++++++++++++++++++++++++++++++-------
 pytools/decorator.py | 16 ++++++++--------
 test/test_pytools.py | 38 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+), 15 deletions(-)

diff --git a/pytools/__init__.py b/pytools/__init__.py
index 17550f6..4328b9e 100644
--- a/pytools/__init__.py
+++ b/pytools/__init__.py
@@ -369,24 +369,48 @@ def single_valued(iterable, equality_pred=operator.eq):
 
 # {{{ memoization / attribute storage
 
-@my_decorator
-def memoize(func, *args):
+def _memoize(func, key, args, kwargs):
     # by Michele Simionato
     # http://www.phyast.pitt.edu/~micheles/python/
 
     try:
-        return func._memoize_dic[args]
+        return func._memoize_dic[key]
     except AttributeError:
         # _memoize_dic doesn't exist yet.
 
-        result = func(*args)
-        func._memoize_dic = {args: result}
+        result = func(*args, **kwargs)
+        func._memoize_dic = {key: result}
         return result
     except KeyError:
-        result = func(*args)
-        func._memoize_dic[args] = result
+        result = func(*args, **kwargs)
+        func._memoize_dic[key] = result
         return result
 
+try:
+    dict.iteritems
+    def _iteritem(d):
+        return d.iteritems()
+except AttributeError:
+    def _iteritem(d):
+        return d.items()
+
+def memoize(*args, **kwargs):
+    key_func = kwargs.pop(
+        'key', lambda *a, **kw: (a, frozenset(_iteritem(kw))))
+    if kwargs:
+        raise TypeError(
+            "memorize recived unexpected keyword arguments: %s"
+            % ", ".join(kwargs.keys()))
+    @my_decorator
+    def _deco(func, *args, **kwargs):
+        return _memoize(func, key_func(*args, **kwargs), args, kwargs)
+    if not args:
+        return _deco
+    if callable(args[0]) and len(args) == 1:
+        return _deco(args[0])
+    raise TypeError(
+        "memorize recived unexpected position arguments: %s" % args)
+
 FunctionValueCache = memoize
 
 
diff --git a/pytools/decorator.py b/pytools/decorator.py
index cf8ead3..502da0f 100644
--- a/pytools/decorator.py
+++ b/pytools/decorator.py
@@ -10,7 +10,7 @@
 __all__ = ["decorator", "update_wrapper", "getinfo"]
 
 import inspect
-    
+
 def getinfo(func):
     """
     Returns an info dictionary containing:
@@ -21,7 +21,7 @@ def getinfo(func):
     - doc (the docstring : str)
     - module (the module name : str)
     - dict (the function __dict__ : str)
-    
+
     >>> def f(self, x=1, y=2, *args, **kw): pass
 
     >>> info = getinfo(f)
@@ -30,7 +30,7 @@ def getinfo(func):
     'f'
     >>> info["argnames"]
     ['self', 'x', 'y', 'args', 'kw']
-    
+
     >>> info["defaults"]
     (1, 2)
 
@@ -68,7 +68,7 @@ def update_wrapper(wrapper, wrapped, create=False):
     if create: # create a brand new wrapper with the right signature
         src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict
         # import sys; print >> sys.stderr, src # for debugging purposes
-        wrapper = eval(src, dict(_wrapper_=wrapper))        
+        wrapper = eval(src, dict(_wrapper_=wrapper))
     try:
         wrapper.__name__ = infodict['name']
     except: # Python version < 2.4
@@ -102,7 +102,7 @@ def decorator(caller, func=None):
      def caller(func, *args, **kw):
          # do something
          return func(*args, **kw)
-    
+
     Here is an example of usage:
 
     >>> @decorator
@@ -112,7 +112,7 @@ def decorator(caller, func=None):
 
     >>> chatty.__name__
     'chatty'
-    
+
     >>> @chatty
     ... def f(): pass
     ...
@@ -137,13 +137,13 @@ if __name__ == "__main__":
     import doctest; doctest.testmod()
 
 #######################     LEGALESE    ##################################
-      
+
 ##   Redistributions of source code must retain the above copyright 
 ##   notice, this list of conditions and the following disclaimer.
 ##   Redistributions in bytecode form must reproduce the above copyright
 ##   notice, this list of conditions and the following disclaimer in
 ##   the documentation and/or other materials provided with the
-##   distribution. 
+##   distribution.
 
 ##   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 ##   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
diff --git a/test/test_pytools.py b/test/test_pytools.py
index f61e0e4..c6f8a72 100644
--- a/test/test_pytools.py
+++ b/test/test_pytools.py
@@ -89,3 +89,41 @@ def test_p_convergence_verifier():
         pconv_verifier.add_data_point(order, 2)
     with pytest.raises(AssertionError):
         pconv_verifier()
+
+
+def test_memoize():
+    from pytools import memoize
+    count = [0]
+
+    @memoize
+    def f(i, j=1):
+        count[0] += 1
+        return i + j
+
+    assert f(1) == 2
+    assert f(1, 2) == 3
+    assert f(2, j=3) == 5
+    assert count[0] == 3
+    assert f(1) == 2
+    assert f(1, 2) == 3
+    assert f(2, j=3) == 5
+    assert count[0] == 3
+
+
+def test_memoize_keyfunc():
+    from pytools import memoize
+    count = [0]
+
+    @memoize(key=lambda i, j=(1,): (i, len(j)))
+    def f(i, j=(1,)):
+        count[0] += 1
+        return i + len(j)
+
+    assert f(1) == 2
+    assert f(1, [2]) == 2
+    assert f(2, j=[2, 3]) == 4
+    assert count[0] == 2
+    assert f(1) == 2
+    assert f(1, (2,)) == 2
+    assert f(2, j=(2, 3)) == 4
+    assert count[0] == 2
-- 
GitLab