diff --git a/dagrt/codegen/analysis.py b/dagrt/codegen/analysis.py index 2d4e40c153d937b9414dfcd9b577897642f84aaf..8d61bc122e5c7394e7852c1e1b4e97fc38e057fe 100644 --- a/dagrt/codegen/analysis.py +++ b/dagrt/codegen/analysis.py @@ -247,4 +247,34 @@ def collect_ode_component_names_from_dag(dag): # }}} + +# {{{ variable to last dependent statement mapping + +def var_to_last_dependent_statement_mapping(names, statement_lists): + """For each function in names, return a mapping of each variable to the + latest statement in statement_lists at which that variable is used. + This is used for intermediate deallocation of variables that no longer + need to be read or written. + + :arg names: a list of function names in the ast. + :arg statement_lists: a set of topological orderings of the statements + in each function. + """ + + tbl = {} + + for name, stmts in zip(names, statement_lists): + + for statement in stmts: + # Associate latest statement in this phase at which + # a given variable is used + read_and_written = statement.get_read_variables().union( + statement.get_written_variables()) + for variable in read_and_written: + tbl[variable, name] = statement.id + + return tbl + +# }}} + # vim: foldmethod=marker diff --git a/dagrt/codegen/fortran.py b/dagrt/codegen/fortran.py index 8f55710f6b3f1969358b0243985ca80f7dab9196..535779cd3043cc0575bfd8d1f24c7b5635eb7881 100644 --- a/dagrt/codegen/fortran.py +++ b/dagrt/codegen/fortran.py @@ -1046,7 +1046,8 @@ class CodeGenerator(StructuredCodeGenerator): # {{{ produce function name / function AST pairs - from dagrt.codegen.dag_ast import create_ast_from_phase + from dagrt.codegen.dag_ast import ( + create_ast_from_phase, get_statements_in_ast) from collections import namedtuple NameASTPair = namedtuple("NameASTPair", "name, ast") # noqa @@ -1064,9 +1065,16 @@ class CodeGenerator(StructuredCodeGenerator): [fd.name for fd in fdescrs], [fd.ast for fd in fdescrs]) - from dagrt.codegen.analysis import collect_ode_component_names_from_dag + from dagrt.codegen.analysis import ( + collect_ode_component_names_from_dag, + var_to_last_dependent_statement_mapping) + component_ids = collect_ode_component_names_from_dag(dag) + self.last_used_stmt_table = var_to_last_dependent_statement_mapping( + [fd.name for fd in fdescrs], + [get_statements_in_ast(fd.ast) for fd in fdescrs]) + if not component_ids <= set(self.user_type_map): raise RuntimeError("User type missing from user type map: %r" % (component_ids - set(self.user_type_map))) @@ -1144,7 +1152,8 @@ class CodeGenerator(StructuredCodeGenerator): self.current_function, {}) for identifier, sym_kind in sorted(six.iteritems(sym_table)): - self.emit_variable_deinit(identifier, sym_kind) + if (identifier, self.current_function) not in self.last_used_stmt_table: + self.emit_variable_deinit(identifier, sym_kind) # }}} @@ -2115,6 +2124,8 @@ class CodeGenerator(StructuredCodeGenerator): for ident, start, stop in inst.loops[::-1]: self.emitter.__exit__(None, None, None) + self.emit_deinit_for_last_usage_of_vars(inst) + assert start_em is self.emitter # }}} @@ -2192,6 +2203,8 @@ class CodeGenerator(StructuredCodeGenerator): + assignee_fortran_names ))) + self.emit_deinit_for_last_usage_of_vars(inst) + # }}} def emit_return(self): @@ -2231,6 +2244,29 @@ class CodeGenerator(StructuredCodeGenerator): (var(self.component_name_to_component_sym( inst.component_id)),))) + def emit_deinit_for_last_usage_of_vars(self, inst): + """Check if, for any of the variables in instruction *inst*, + *inst* contains the last use of that variable in the + :attr:`current_function`. If so, emit code to deallocate that variable. + """ + from dagrt.utils import is_state_variable + + read_and_written = inst.get_read_variables() | inst.get_written_variables() + + for variable in read_and_written: + # FIXME: This can fail for args of state update notification, + # hence the try/catch. + try: + var_kind = self.sym_kind_table.get( + self.current_function, variable) + except KeyError: + continue + + last_used_stmt_id = self.last_used_stmt_table[ + variable, self.current_function] + if inst.id == last_used_stmt_id and not is_state_variable(variable): + self.emit_variable_deinit(variable, var_kind) + def emit_inst_Raise(self, inst): # FIXME: Reenable emitting full error message # TBD: Quoting of quotes, extra-long lines