diff --git a/examples/dagrt_fusion.py b/examples/dagrt_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..1099be1163f99f4a1a11b7a654cce6060984d9b7 --- /dev/null +++ b/examples/dagrt_fusion.py @@ -0,0 +1,170 @@ +import pyopencl as cl +from grudge.models.wave import StrongWaveOperator +from leap.rk import LSRK4Method +import numpy as np # noqa + +import dagrt.language as lang +import pymbolic.primitives as p +import grudge.symbolic.mappers as gmap +from pymbolic.mapper.evaluator import EvaluationMapper \ + as PymbolicEvaluationMapper + +from grudge import sym, bind, DGDiscretizationWithBoundaries + + +def topological_sort(stmts, root_deps): + id_to_stmt = {stmt.id: stmt for stmt in stmts} + + ordered_stmts = [] + satisfied = set() + + def satisfy_dep(name): + if name in satisfied: + return + + stmt = id_to_stmt[name] + for dep in stmt.depends_on: + satisfy_dep(dep) + ordered_stmts.append(stmt) + satisfied.add(name) + + for d in root_deps: + satisfy_dep(d) + + return ordered_stmts + + +# Use evaluation, not identity mappers to propagate symbolic vectors to +# outermost level. + +class DagrtToGrudgeRewriter(PymbolicEvaluationMapper): + def __init__(self, context): + self.context = context + + def map_variable(self, expr): + return self.context[expr.name] + + def map_call(self, expr): + raise ValueError("function call not expected") + + +class GrudgeArgSubstitutor(gmap.SymbolicEvaluator): + def __init__(self, args): + self.args = args + + def map_grudge_variable(self, expr): + if expr.name in self.args: + return self.args[expr.name] + else: + return super().map_variable(expr) + + +def transcribe_phase(dag, phase_name, rhs_name, sym_operator): + sym_operator = gmap.OperatorBinder()(sym_operator) + + phase = dag.phases[phase_name] + + ctx = { + "": sym.var("input_t", sym.DD_SCALAR), + "
": sym.var("input_dt", sym.DD_SCALAR), + "w": sym.make_sym_array("input_w", 3), + "

residual_w": sym.var("input_residual") + } + output_vars = [v for v in ctx.keys()] + yielded_states = [] + + from dagrt.codegen.transform import isolate_function_calls_in_phase + ordered_stmts = topological_sort( + isolate_function_calls_in_phase( + phase, + dag.get_stmt_id_generator(), + dag.get_var_name_generator()).statements, + phase.depends_on) + + for stmt in ordered_stmts: + if stmt.condition is not True: + raise NotImplementedError( + "non-True condition (in statement '%s') not supported" + % stmt.id) + + if isinstance(stmt, lang.Nop): + pass + elif isinstance(stmt, lang.AssignExpression): + if not isinstance(stmt.lhs, p.Variable): + raise NotImplementedError("lhs of statement %s is not a variable: %s" + % (stmt.id, stmt.lhs)) + ctx[stmt.lhs.name] = sym.cse( + DagrtToGrudgeRewriter(ctx)(stmt.rhs), + ( + stmt.lhs.name + .replace("<", "") + .replace(">", ""))) + elif isinstance(stmt, lang.AssignFunctionCall): + if stmt.function_id != rhs_name: + raise NotImplementedError( + "statement '%s' calls unsupported function '%s'" + % (stmt.id, stmt.function_id)) + if stmt.parameters: + raise NotImplementedError( + "statement '%s' calls function '%s' with positional arguments" + % (stmt.id, stmt.function_id)) + kwargs = {name: sym.cse(DagrtToGrudgeRewriter(ctx)(arg)) + for name, arg in stmt.kw_parameters.items()} + + if len(stmt.assignees) != 1: + raise NotImplementedError( + "statement '%s' calls function '%s' " + "with more than one LHS" + % (stmt.id, stmt.function_id)) + + assignee, = stmt.assignees + ctx[assignee] = GrudgeArgSubstitutor( + kwargs)(sym_operator) + elif isinstance(stmt, lang.YieldState): + d2g = DagrtToGrudgeRewriter(ctx) + yielded_states.append( + (stmt.time_id, d2g(stmt.time), stmt.component_id, + d2g(stmt.expression))) + else: + raise NotImplementedError("statement %s is of unsupported type ''%s'" + % (stmt.id, type(stmt).__name__)) + + return output_vars, [ctx[ov] for ov in output_vars], yielded_states + + +def main(): + dt_method = LSRK4Method(component_id="w") + dt_code = dt_method.generate() + + from meshmode.mesh import BTAG_ALL, BTAG_NONE + op = StrongWaveOperator(-0.1, 2, + dirichlet_tag=BTAG_NONE, + neumann_tag=BTAG_NONE, + radiation_tag=BTAG_ALL, + flux_type="upwind") + sym_op = op.sym_operator() + + ov, results, ys = transcribe_phase(dt_code, "primary", "w", sym_op) + for ov_i, res_i in zip(ov, results): + print(75*"#") + print(sym.pretty(res_i)) + print(f"{ov_i} IS THE ABOVE ^") + + from meshmode.mesh.generation import generate_regular_rect_mesh + mesh = generate_regular_rect_mesh( + a=(-0.5,)*2, + b=(0.5,)*2, + n=(16,)*2) + + cl_ctx = cl.create_some_context() + discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=3) + + from pytools.obj_array import join_fields + results = join_fields(results[0], results[1], *results[2], *results[3]) + bound_op = bind(discr, results) + + print(bound_op.eval_code) + + +if __name__ == "__main__": + main() diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index 53db42f18665fa37582c6e92afc5b214916f8e66..fc94e5eba0b80bf1baf123de12d0aac26d0d99ad 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -446,7 +446,8 @@ class Code(object): available_insns = [ (insn, insn.priority) for insn in self.instructions if insn not in done_insns - and all(dep.name in available_names + and all((dep.aggregate.name if isinstance(dep, Subscript) + else dep.name) in available_names for dep in insn.get_dependencies())] if not available_insns: @@ -454,7 +455,8 @@ class Code(object): from pytools import flatten discardable_vars = set(available_names) - set(flatten( - [dep.name for dep in insn.get_dependencies()] + [dep.aggregate.name if isinstance(dep, Subscript) else dep.name + for dep in insn.get_dependencies()] for insn in self.instructions if insn not in done_insns)) diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index c2fba3d6a757fa9ed80babf4f7315707bf4e7858..76864e7d1b4ae389ea3fc20e69e7d520908555eb 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -1269,6 +1269,33 @@ class FluxExchangeCollector(CSECachingMapperMixin, CollectorMixin, CombineMapper class Evaluator(pymbolic.mapper.evaluator.EvaluationMapper): pass + +class SymbolicEvaluator(pymbolic.mapper.evaluator.EvaluationMapper): + def map_operator_binding(self, expr, *args, **kwargs): + return expr.op(self.rec(expr.field, *args, **kwargs)) + + def map_node_coordinate_component(self, expr, *args, **kwargs): + return expr + + def map_call(self, expr, *args, **kwargs): + return type(expr)( + expr.function, + tuple(self.rec(child, *args, **kwargs) + for child in expr.parameters)) + + def map_call_with_kwargs(self, expr, *args, **kwargs): + return type(expr)( + expr.function, + tuple(self.rec(child, *args, **kwargs) + for child in expr.parameters), + dict( + (key, self.rec(val, *args, **kwargs)) + for key, val in six.iteritems(expr.kw_parameters)) + ) + + def map_common_subexpression(self, expr): + return type(expr)(self.rec(expr.child), expr.prefix, expr.scope) + # }}}