diff --git a/pytools/py_codegen.py b/pytools/py_codegen.py index ee45a13b023148b996a65cc740914a1858ff6380..a63a928acf10593c6bfa02801e8a7693386f29e9 100644 --- a/pytools/py_codegen.py +++ b/pytools/py_codegen.py @@ -22,6 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six + # loosely based on # http://effbot.org/zone/python-code-generator.htm @@ -74,6 +76,18 @@ class PythonCodeGenerator(object): raise RuntimeError("internal error in python code generator") self.level -= 1 + def get_module(self, name="<generated code>"): + result_dict = {} + source_text = self.get() + exec(compile( + source_text.rstrip()+"\n", name, "exec"), + result_dict) + result_dict["_MODULE_SOURCE_CODE"] = source_text + return result_dict + + def get_picklable_module(self): + return PicklableModule(self.get_module()) + class PythonFunctionGenerator(PythonCodeGenerator): def __init__(self, name, args): @@ -84,15 +98,71 @@ class PythonFunctionGenerator(PythonCodeGenerator): self.indent() def get_function(self): - result_dict = {} - source_text = self.get() - exec(compile( - source_text.rstrip()+"\n", "<generated function %s>" - % self.name, "exec"), - result_dict) - func = result_dict[self.name] - result_dict["_MODULE_SOURCE_CODE"] = source_text - return func + return self.get_module()[self.name] + + +# {{{ pickling of binaries for generated code + +def _get_empty_module_dict(): + result_dict = {} + exec(compile("", "<generated function>", "exec"), result_dict) + return result_dict + + +_empty_module_dict = _get_empty_module_dict() + + +class PicklableModule(object): + def __init__(self, mod_globals): + self.mod_globals = mod_globals + + def __getstate__(self): + import marshal + + nondefault_globals = {} + functions = {} + + from types import FunctionType + for k, v in six.iteritems(self.mod_globals): + if isinstance(v, FunctionType): + functions[k] = ( + v.__name__, + marshal.dumps(v.__code__), + v.__defaults__) + + elif k not in _empty_module_dict: + nondefault_globals[k] = v + + import imp + return (0, imp.get_magic(), functions, nondefault_globals) + + def __setstate__(self, obj): + v = obj[0] + if v == 0: + magic, functions, nondefault_globals = obj[1:] + else: + raise ValueError("unknown version of PicklableGeneratedFunction") + + import imp + if magic != imp.get_magic(): + raise ValueError("cannot unpickle function binary: " + "incorrect magic value (got: %s, expected: %s)" + % (magic, imp.get_magic())) + + import marshal + + mod_globals = _empty_module_dict.copy() + mod_globals.update(nondefault_globals) + self.mod_globals = mod_globals + + from types import FunctionType + for k, v in six.iteritems(functions): + name, code_bytes, argdefs = v + f = FunctionType( + marshal.loads(code_bytes), mod_globals, argdefs=argdefs) + mod_globals[k] = f + +# }}} # {{{ remove common indentation @@ -122,3 +192,5 @@ def remove_common_indentation(code, require_leading_newline=True): return "\n".join(line[base_indent:] for line in lines) # }}} + +# vim: foldmethod=marker