diff --git a/pycuda/tools.py b/pycuda/tools.py index 1a2b50f4dd786628a5721b98f01accab013bd875..05ac3c52ec3e2e998c6e9be4337d1b8dfa1f04df 100644 --- a/pycuda/tools.py +++ b/pycuda/tools.py @@ -26,7 +26,6 @@ OTHER DEALINGS IN THE SOFTWARE. """ import pycuda.driver as cuda -from decorator import decorator import pycuda._driver as _drv import numpy as np @@ -451,25 +450,34 @@ def get_arg_type(c_arg): context_dependent_memoized_functions = [] -@decorator -def context_dependent_memoize(func, *args): - try: - ctx_dict = func._pycuda_ctx_dep_memoize_dic - except AttributeError: - # FIXME: This may keep contexts alive longer than desired. - # But I guess since the memory in them is freed, who cares. - ctx_dict = func._pycuda_ctx_dep_memoize_dic = {} +def context_dependent_memoize(func): + def wrapper(*args, **kwargs): + if kwargs: + cache_key = (args, frozenset(kwargs.items())) + else: + cache_key = (args,) - cur_ctx = cuda.Context.get_current() + try: + ctx_dict = func._pycuda_ctx_dep_memoize_dic + except AttributeError: + # FIXME: This may keep contexts alive longer than desired. + # But I guess since the memory in them is freed, who cares. + ctx_dict = func._pycuda_ctx_dep_memoize_dic = {} - try: - return ctx_dict[cur_ctx][args] - except KeyError: - context_dependent_memoized_functions.append(func) - arg_dict = ctx_dict.setdefault(cur_ctx, {}) - result = func(*args) - arg_dict[args] = result - return result + cur_ctx = cuda.Context.get_current() + + try: + return ctx_dict[cur_ctx][cache_key] + except KeyError: + context_dependent_memoized_functions.append(func) + arg_dict = ctx_dict.setdefault(cur_ctx, {}) + result = func(*args, **kwargs) + arg_dict[cache_key] = result + return result + + from functools import update_wrapper + update_wrapper(wrapper, func) + return wrapper def clear_context_caches(): diff --git a/setup.py b/setup.py index 990a7830dbdadf77a41ce4e2f3b1cebc81a9bc03..44545e16e71e5d2f6a17c7f6e1207abc63464da6 100644 --- a/setup.py +++ b/setup.py @@ -225,7 +225,6 @@ def main(): python_requires="~=3.6", install_requires=[ "pytools>=2011.2", - "decorator>=3.2.0", "appdirs>=1.4.0", "mako", ],