From d34126aced13ae117ae5ee97062445053310f223 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 28 May 2021 12:27:57 -0500 Subject: [PATCH] Fix dagrt-fusion example for https://github.com/inducer/dagrt/pull/12 --- examples/old_symbolics/dagrt-fusion.py | 114 ++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 1 deletion(-) diff --git a/examples/old_symbolics/dagrt-fusion.py b/examples/old_symbolics/dagrt-fusion.py index ea6cb6bd..c431e21f 100755 --- a/examples/old_symbolics/dagrt-fusion.py +++ b/examples/old_symbolics/dagrt-fusion.py @@ -59,6 +59,7 @@ import pytest import dagrt.language as lang import pymbolic.primitives as p +from pymbolic.mapper import IdentityMapper from arraycontext import PyOpenCLArrayContext @@ -170,6 +171,118 @@ class GrudgeArgSubstitutor(gmap.SymbolicEvaluator): return super().map_variable(expr) +# {{{ isolate function calls + +# (copied from dagrt pre-https://github.com/inducer/dagrt/pull/12) + +class FunctionCallIsolator(IdentityMapper): + def __init__(self, new_statements, + stmt_id_gen, var_name_gen): + super().__init__() + self.new_statements = new_statements + self.stmt_id_gen = stmt_id_gen + self.var_name_gen = var_name_gen + + def isolate_call(self, expr, base_condition, base_deps, extra_deps, + super_method): + # FIXME: These aren't awesome identifiers. + tmp_var_name = self.var_name_gen("tmp") + + tmp_stmt_id = self.stmt_id_gen("tmp") + extra_deps.append(tmp_stmt_id) + + sub_extra_deps = [] + rec_result = super_method( + expr, base_deps, sub_extra_deps) + + from pymbolic.primitives import Call, CallWithKwargs + assert isinstance(rec_result, (Call, CallWithKwargs)) + + parameters = [] + kw_parameters = {} + + for par in rec_result.parameters: + parameters.append(par) + + if isinstance(rec_result, CallWithKwargs): + for par_name, par in rec_result.kw_parameters.items(): + kw_parameters[par_name] = par + + from dagrt.language import AssignFunctionCall + new_stmt = AssignFunctionCall( + assignees=(tmp_var_name,), + function_id=rec_result.function.name, + parameters=tuple(parameters), + kw_parameters=kw_parameters, + id=tmp_stmt_id, + condition=base_condition, + depends_on=base_deps | frozenset(sub_extra_deps)) + + self.new_statements.append(new_stmt) + + from pymbolic import var + return var(tmp_var_name) + + def map_call(self, expr, base_condition, base_deps, extra_deps): + return self.isolate_call( + expr, base_condition, base_deps, extra_deps, + super().map_call) + + def map_call_with_kwargs(self, expr, base_condition, base_deps, extra_deps): + return self.isolate_call( + expr, base_condition, base_deps, extra_deps, + super() + .map_call_with_kwargs) + + +def isolate_function_calls_in_phase(phase, stmt_id_gen, var_name_gen): + new_statements = [] + + fci = FunctionCallIsolator( + new_statements=new_statements, + stmt_id_gen=stmt_id_gen, + var_name_gen=var_name_gen) + + for stmt in sorted(phase.statements, key=lambda stmt: stmt.id): + new_deps = [] + + from dagrt.language import Assign + if isinstance(stmt, Assign): + assert not stmt.loops + new_statements.append( + stmt + .map_expressions( + lambda expr: fci( + expr, stmt.condition, stmt.depends_on, new_deps)) + .copy(depends_on=stmt.depends_on | frozenset(new_deps))) + from pymbolic.primitives import Call, CallWithKwargs + assert not isinstance(new_statements[-1].rhs, + (Call, CallWithKwargs)) + else: + new_statements.append(stmt) + + return phase.copy(statements=new_statements) + + +def isolate_function_calls(dag): + """ + :func:`isolate_function_arguments` should be + called before this. + """ + + stmt_id_gen = dag.get_stmt_id_generator() + var_name_gen = dag.get_var_name_generator() + + new_phases = {} + for phase_name, phase in dag.phases.items(): + new_phases[phase_name] = isolate_function_calls_in_phase( + phase, stmt_id_gen, var_name_gen) + + return dag.copy(phases=new_phases) + +# }}} + + def transcribe_phase(dag, field_var_name, field_components, phase_name, sym_operator): """Generate a Grudge operator for a Dagrt time integrator phase. @@ -203,7 +316,6 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name, output_vars = [v for v in ctx] yielded_states = [] - from dagrt.codegen.transform import isolate_function_calls_in_phase ordered_stmts = topological_sort( isolate_function_calls_in_phase( phase, -- GitLab