diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index f95338d6e6156dba2732151fcdcc7adeb3004767..137b30ea8c671751346ec07c7d21252c948fe08c 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -883,7 +883,10 @@ class CodeGenerator(StructuredCodeGenerator): trace=False): """ - :arg function_registry: + :arg function_registry: An instance of + :class:`dagrt.function_registry.FunctionRegistry` + :arg module_preamble: A string to include at the beginning of the + emitted module :arg user_type_map: a map from user type names to :class:`FortranType` instances :arg call_before_state_update: The name of a function that should diff --git a/dagrt/codegen/python.py b/dagrt/codegen/python.py index 3f6ed3b4938e05d3793fb188948342084c914e00..cdac431058756199a3f4f6f7c26b058ff80c7cb2 100644 --- a/dagrt/codegen/python.py +++ b/dagrt/codegen/python.py @@ -185,11 +185,21 @@ class BareExpression(object): class CodeGenerator(StructuredCodeGenerator): - def __init__(self, class_name, function_registry=None): + def __init__(self, class_name, class_preamble=None, function_registry=None): + """ + :arg class_name: The name of the class to generate + :arg class_preamble: A string to include at the beginning of the + the class (in class scope) + :arg function_registry: An instance of + :class:`dagrt.function_registry.FunctionRegistry` + """ if function_registry is None: from dagrt.function_registry import base_function_registry function_registry = base_function_registry + from dagrt.codegen.utils import remove_common_indentation + self.class_preamble = remove_common_indentation(class_preamble) + self._class_name = class_name self._class_emitter = PythonClassEmitter(class_name) @@ -243,6 +253,13 @@ class CodeGenerator(StructuredCodeGenerator): self._emitter(wrapped_line) def begin_emit(self, dag): + if self.class_preamble: + emit = PythonEmitter() + for line in self.class_preamble: + emit(line) + emit("") + self._class_emitter.incorporate(emit) + self._emit_inner_classes() def _emit_inner_classes(self): diff --git a/test/test_codegen_python.py b/test/test_codegen_python.py index 237c736632c0a778ca8241ff127b0c22651a4cf1..ff0b29ebeee16474a177fcbf99af392b2faf969f 100755 --- a/test/test_codegen_python.py +++ b/test/test_codegen_python.py @@ -426,6 +426,44 @@ def test_svd(python_method_impl): assert la.norm(result) < 1e-10 +def test_class_preamble(): + from dagrt.language import CodeBuilder + + with CodeBuilder(label="primary") as cb: + cb.assign("", " +
") + cb.yield_state("f()", "f", 0, "final") + + code = DAGCode.create_with_steady_phase(cb.phase_dependencies, cb.statements) + + from dagrt.codegen import PythonCodeGenerator + import dagrt.function_registry as freg + + preamble = """ + @staticmethod + def f(): + return 1 + """ + + f = freg.Function( + identifier="f", + language_to_codegen={"python": lambda *args: "self.f()"}) + + generator = PythonCodeGenerator( + "PythonMethod", + class_preamble=preamble, + function_registry=freg.base_function_registry.register(f)) + + class_ = generator.get_class(code) + + method = class_(function_map={}) + method.set_up(t_start=0, dt_start=1, context={}) + + events = list(method.run(t_end=1)) + assert events + assert isinstance(events[0], class_.StateComputed) + assert events[0].state_component == 1 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])