From b3a1364df4b143b5b15c36deb21ecca272643b2f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 4 Nov 2018 23:56:56 -0600 Subject: [PATCH 1/2] Factor out isolate_function_calls_in_phase --- dagrt/codegen/transform.py | 56 +++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/dagrt/codegen/transform.py b/dagrt/codegen/transform.py index ea8831e..3b0a17c 100644 --- a/dagrt/codegen/transform.py +++ b/dagrt/codegen/transform.py @@ -232,6 +232,35 @@ class FunctionCallIsolator(IdentityMapper): .map_call_with_kwargs) + +def isolate_function_calls_in_phase(phase, stmt_id_gen, var_name_gen): + new_statements = [] + + fci = FunctionCallIsolator( + new_statements=new_statements, + stmt_id_gen=stmt_id_gen, + var_name_gen=var_name_gen) + + for stmt in sorted(phase.statements, key=lambda stmt: stmt.id): + new_deps = [] + + from dagrt.language import AssignExpression + if isinstance(stmt, AssignExpression): + new_statements.append( + stmt + .map_expressions( + lambda expr: fci( + expr, stmt.condition, stmt.depends_on, new_deps)) + .copy(depends_on=stmt.depends_on | frozenset(new_deps))) + from pymbolic.primitives import Call, CallWithKwargs + assert not isinstance(new_statements[-1].rhs, + (Call, CallWithKwargs)) + else: + new_statements.append(stmt) + + return phase.copy(statements=new_statements) + + def isolate_function_calls(dag): """ :func:`isolate_function_arguments` should be @@ -243,31 +272,8 @@ def isolate_function_calls(dag): new_phases = {} for phase_name, phase in six.iteritems(dag.phases): - new_statements = [] - - fci = FunctionCallIsolator( - new_statements=new_statements, - stmt_id_gen=stmt_id_gen, - var_name_gen=var_name_gen) - - for stmt in sorted(phase.statements, key=lambda stmt: stmt.id): - new_deps = [] - - from dagrt.language import AssignExpression - if isinstance(stmt, AssignExpression): - new_statements.append( - stmt - .map_expressions( - lambda expr: fci( - expr, stmt.condition, stmt.depends_on, new_deps)) - .copy(depends_on=stmt.depends_on | frozenset(new_deps))) - from pymbolic.primitives import Call, CallWithKwargs - assert not isinstance(new_statements[-1].rhs, - (Call, CallWithKwargs)) - else: - new_statements.append(stmt) - - new_phases[phase_name] = phase.copy(statements=new_statements) + new_phases[phase_name] = isolate_function_calls_in_phase( + phase, stmt_id_gen, var_name_gen) return dag.copy(phases=new_phases) -- GitLab From 79a5d92ae41454da0edc8203e5f499b037113092 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 4 Nov 2018 23:59:09 -0600 Subject: [PATCH 2/2] Placate flake8 --- dagrt/codegen/transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dagrt/codegen/transform.py b/dagrt/codegen/transform.py index 3b0a17c..026832f 100644 --- a/dagrt/codegen/transform.py +++ b/dagrt/codegen/transform.py @@ -232,7 +232,6 @@ class FunctionCallIsolator(IdentityMapper): .map_call_with_kwargs) - def isolate_function_calls_in_phase(phase, stmt_id_gen, var_name_gen): new_statements = [] -- GitLab