diff --git a/pytools/importlib_backport.py b/pytools/importlib_backport.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5275c3b542fbbe3fcf1ba7fb9c216f3c9aa110 --- /dev/null +++ b/pytools/importlib_backport.py @@ -0,0 +1,94 @@ +"""Backport of importlib.import_module from 3.x. + +Downloaded from: https://github.com/sprintly/importlib + +This code is based in the implementation of importlib.import_module() +in Python 2.7. The license is below. + +======================================================================== + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF +hereby grants Licensee a nonexclusive, royalty-free, world-wide +license to reproduce, analyze, test, perform and/or display publicly, +prepare derivative works, distribute, and otherwise use Python +alone or in any derivative version, provided, however, that PSF's +License Agreement and PSF's notice of copyright, i.e., "Copyright (c) +2001, 2002, 2003, 2004, 2005, 2006 Python Software Foundation; All Rights +Reserved" are retained in Python alone or in any derivative version +prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + +""" + +# While not critical (and in no way guaranteed!), it would be nice to keep this +# code compatible with Python 2.3. +import sys +import six + + +def _resolve_name(name, package, level): + """Return the absolute name of the module to be imported.""" + if not hasattr(package, 'rindex'): + raise ValueError("'package' not set to a string") + dot = len(package) + for x in six.moves.xrange(level, 1, -1): + try: + dot = package.rindex('.', 0, dot) + except ValueError: + raise ValueError("attempted relative import beyond top-level " + "package") + return "%s.%s" % (package[:dot], name) + + +def import_module(name, package=None): + """Import a module. + The 'package' argument is required when performing a relative import. It + specifies the package to use as the anchor point from which to resolve the + relative import to an absolute import. + """ + if name.startswith('.'): + if not package: + raise TypeError("relative imports require the 'package' argument") + level = 0 + for character in name: + if character != '.': + break + level += 1 + name = _resolve_name(name[level:], package, level) + __import__(name) + return sys.modules[name] diff --git a/pytools/py_codegen.py b/pytools/py_codegen.py index a63a928acf10593c6bfa02801e8a7693386f29e9..2adb0aaa031a4b7155c3093b8309b1edebdf1a8e 100644 --- a/pytools/py_codegen.py +++ b/pytools/py_codegen.py @@ -100,6 +100,10 @@ class PythonFunctionGenerator(PythonCodeGenerator): def get_function(self): return self.get_module()[self.name] + def get_picklable_function(self): + module = self.get_picklable_module() + return PicklableFunction(module, self.name) + # {{{ pickling of binaries for generated code @@ -121,27 +125,32 @@ class PicklableModule(object): nondefault_globals = {} functions = {} + modules = {} - from types import FunctionType + from types import FunctionType, ModuleType for k, v in six.iteritems(self.mod_globals): if isinstance(v, FunctionType): functions[k] = ( v.__name__, marshal.dumps(v.__code__), v.__defaults__) - + elif isinstance(v, ModuleType): + modules[k] = v.__name__ elif k not in _empty_module_dict: nondefault_globals[k] = v import imp - return (0, imp.get_magic(), functions, nondefault_globals) + return (1, imp.get_magic(), functions, modules, nondefault_globals) def __setstate__(self, obj): v = obj[0] if v == 0: magic, functions, nondefault_globals = obj[1:] + modules = {} + elif v == 1: + magic, functions, modules, nondefault_globals = obj[1:] else: - raise ValueError("unknown version of PicklableGeneratedFunction") + raise ValueError("unknown version of PicklableModule") import imp if magic != imp.get_magic(): @@ -155,6 +164,11 @@ class PicklableModule(object): mod_globals.update(nondefault_globals) self.mod_globals = mod_globals + from pytools.importlib_backport import import_module + + for k, mod_name in six.iteritems(modules): + mod_globals[k] = import_module(mod_name) + from types import FunctionType for k, v in six.iteritems(functions): name, code_bytes, argdefs = v @@ -165,6 +179,32 @@ class PicklableModule(object): # }}} +# {{{ picklable function + +class PicklableFunction(object): + """Convience class wrapping a function in a :class:`PicklableModule`. + """ + + def __init__(self, module, name): + self._initialize(module, name) + + def _initialize(self, module, name): + self.module = module + self.name = name + self.func = module.mod_globals[name] + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def __getstate__(self): + return {"module": self.module, "name": self.name} + + def __setstate__(self, obj): + self._initialize(obj["module"], obj["name"]) + +# }}} + + # {{{ remove common indentation def remove_common_indentation(code, require_leading_newline=True): diff --git a/pytools/version.py b/pytools/version.py index f46939dfb75cb4e66e1c297fd8a5837099f2034f..a74b13294bb6727673d4e3e1683b90ef7ce1751a 100644 --- a/pytools/version.py +++ b/pytools/version.py @@ -1,3 +1,3 @@ -VERSION = (2018, 2) +VERSION = (2018, 3) VERSION_STATUS = "" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS diff --git a/test/test_py_codegen.py b/test/test_py_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..b418cc449cd929abd57119ac5bccfc12e1280d38 --- /dev/null +++ b/test/test_py_codegen.py @@ -0,0 +1,39 @@ +from __future__ import division, with_statement, absolute_import + +import pytest # noqa +import pytools +import pytools.py_codegen as codegen +import sys + + +def test_pickling_with_module_import(): + cg = codegen.PythonCodeGenerator() + cg("import pytools.py_codegen") + cg("import math as m") + + import pickle + mod = pickle.loads(pickle.dumps(cg.get_picklable_module())) + + assert mod.mod_globals["pytools"] is pytools + assert mod.mod_globals["pytools"].py_codegen is pytools.py_codegen + + import math + assert mod.mod_globals["m"] is math + + +def test_picklable_function(): + cg = codegen.PythonFunctionGenerator("f", args=()) + cg("return 1") + + import pickle + f = pickle.loads(pickle.dumps(cg.get_picklable_function())) + + assert f() == 1 + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from py.test import main + main([__file__])