diff --git a/dagrt/codegen/transform.py b/dagrt/codegen/transform.py index ea8831e26fca4584860b0f6d75ff19cf47204f6d..026832ff3505cf90353d1128b8b6dbdab28aefb8 100644 --- a/dagrt/codegen/transform.py +++ b/dagrt/codegen/transform.py @@ -232,6 +232,34 @@ class FunctionCallIsolator(IdentityMapper): .map_call_with_kwargs) +def isolate_function_calls_in_phase(phase, stmt_id_gen, var_name_gen): + new_statements = [] + + fci = FunctionCallIsolator( + new_statements=new_statements, + stmt_id_gen=stmt_id_gen, + var_name_gen=var_name_gen) + + for stmt in sorted(phase.statements, key=lambda stmt: stmt.id): + new_deps = [] + + from dagrt.language import AssignExpression + if isinstance(stmt, AssignExpression): + new_statements.append( + stmt + .map_expressions( + lambda expr: fci( + expr, stmt.condition, stmt.depends_on, new_deps)) + .copy(depends_on=stmt.depends_on | frozenset(new_deps))) + from pymbolic.primitives import Call, CallWithKwargs + assert not isinstance(new_statements[-1].rhs, + (Call, CallWithKwargs)) + else: + new_statements.append(stmt) + + return phase.copy(statements=new_statements) + + def isolate_function_calls(dag): """ :func:`isolate_function_arguments` should be @@ -243,31 +271,8 @@ def isolate_function_calls(dag): new_phases = {} for phase_name, phase in six.iteritems(dag.phases): - new_statements = [] - - fci = FunctionCallIsolator( - new_statements=new_statements, - stmt_id_gen=stmt_id_gen, - var_name_gen=var_name_gen) - - for stmt in sorted(phase.statements, key=lambda stmt: stmt.id): - new_deps = [] - - from dagrt.language import AssignExpression - if isinstance(stmt, AssignExpression): - new_statements.append( - stmt - .map_expressions( - lambda expr: fci( - expr, stmt.condition, stmt.depends_on, new_deps)) - .copy(depends_on=stmt.depends_on | frozenset(new_deps))) - from pymbolic.primitives import Call, CallWithKwargs - assert not isinstance(new_statements[-1].rhs, - (Call, CallWithKwargs)) - else: - new_statements.append(stmt) - - new_phases[phase_name] = phase.copy(statements=new_statements) + new_phases[phase_name] = isolate_function_calls_in_phase( + phase, stmt_id_gen, var_name_gen) return dag.copy(phases=new_phases)