diff --git a/examples/dagrt_fusion.py b/examples/dagrt_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..1099be1163f99f4a1a11b7a654cce6060984d9b7 --- /dev/null +++ b/examples/dagrt_fusion.py @@ -0,0 +1,170 @@ +import pyopencl as cl +from grudge.models.wave import StrongWaveOperator +from leap.rk import LSRK4Method +import numpy as np # noqa + +import dagrt.language as lang +import pymbolic.primitives as p +import grudge.symbolic.mappers as gmap +from pymbolic.mapper.evaluator import EvaluationMapper \ + as PymbolicEvaluationMapper + +from grudge import sym, bind, DGDiscretizationWithBoundaries + + +def topological_sort(stmts, root_deps): + id_to_stmt = {stmt.id: stmt for stmt in stmts} + + ordered_stmts = [] + satisfied = set() + + def satisfy_dep(name): + if name in satisfied: + return + + stmt = id_to_stmt[name] + for dep in stmt.depends_on: + satisfy_dep(dep) + ordered_stmts.append(stmt) + satisfied.add(name) + + for d in root_deps: + satisfy_dep(d) + + return ordered_stmts + + +# Use evaluation, not identity mappers to propagate symbolic vectors to +# outermost level. + +class DagrtToGrudgeRewriter(PymbolicEvaluationMapper): + def __init__(self, context): + self.context = context + + def map_variable(self, expr): + return self.context[expr.name] + + def map_call(self, expr): + raise ValueError("function call not expected") + + +class GrudgeArgSubstitutor(gmap.SymbolicEvaluator): + def __init__(self, args): + self.args = args + + def map_grudge_variable(self, expr): + if expr.name in self.args: + return self.args[expr.name] + else: + return super().map_variable(expr) + + +def transcribe_phase(dag, phase_name, rhs_name, sym_operator): + sym_operator = gmap.OperatorBinder()(sym_operator) + + phase = dag.phases[phase_name] + + ctx = { + "": sym.var("input_t", sym.DD_SCALAR), + "
": sym.var("input_dt", sym.DD_SCALAR), + "w": sym.make_sym_array("input_w", 3), + "

residual_w": sym.var("input_residual") + } + output_vars = [v for v in ctx.keys()] + yielded_states = [] + + from dagrt.codegen.transform import isolate_function_calls_in_phase + ordered_stmts = topological_sort( + isolate_function_calls_in_phase( + phase, + dag.get_stmt_id_generator(), + dag.get_var_name_generator()).statements, + phase.depends_on) + + for stmt in ordered_stmts: + if stmt.condition is not True: + raise NotImplementedError( + "non-True condition (in statement '%s') not supported" + % stmt.id) + + if isinstance(stmt, lang.Nop): + pass + elif isinstance(stmt, lang.AssignExpression): + if not isinstance(stmt.lhs, p.Variable): + raise NotImplementedError("lhs of statement %s is not a variable: %s" + % (stmt.id, stmt.lhs)) + ctx[stmt.lhs.name] = sym.cse( + DagrtToGrudgeRewriter(ctx)(stmt.rhs), + ( + stmt.lhs.name + .replace("<", "") + .replace(">", ""))) + elif isinstance(stmt, lang.AssignFunctionCall): + if stmt.function_id != rhs_name: + raise NotImplementedError( + "statement '%s' calls unsupported function '%s'" + % (stmt.id, stmt.function_id)) + if stmt.parameters: + raise NotImplementedError( + "statement '%s' calls function '%s' with positional arguments" + % (stmt.id, stmt.function_id)) + kwargs = {name: sym.cse(DagrtToGrudgeRewriter(ctx)(arg)) + for name, arg in stmt.kw_parameters.items()} + + if len(stmt.assignees) != 1: + raise NotImplementedError( + "statement '%s' calls function '%s' " + "with more than one LHS" + % (stmt.id, stmt.function_id)) + + assignee, = stmt.assignees + ctx[assignee] = GrudgeArgSubstitutor( + kwargs)(sym_operator) + elif isinstance(stmt, lang.YieldState): + d2g = DagrtToGrudgeRewriter(ctx) + yielded_states.append( + (stmt.time_id, d2g(stmt.time), stmt.component_id, + d2g(stmt.expression))) + else: + raise NotImplementedError("statement %s is of unsupported type ''%s'" + % (stmt.id, type(stmt).__name__)) + + return output_vars, [ctx[ov] for ov in output_vars], yielded_states + + +def main(): + dt_method = LSRK4Method(component_id="w") + dt_code = dt_method.generate() + + from meshmode.mesh import BTAG_ALL, BTAG_NONE + op = StrongWaveOperator(-0.1, 2, + dirichlet_tag=BTAG_NONE, + neumann_tag=BTAG_NONE, + radiation_tag=BTAG_ALL, + flux_type="upwind") + sym_op = op.sym_operator() + + ov, results, ys = transcribe_phase(dt_code, "primary", "w", sym_op) + for ov_i, res_i in zip(ov, results): + print(75*"#") + print(sym.pretty(res_i)) + print(f"{ov_i} IS THE ABOVE ^") + + from meshmode.mesh.generation import generate_regular_rect_mesh + mesh = generate_regular_rect_mesh( + a=(-0.5,)*2, + b=(0.5,)*2, + n=(16,)*2) + + cl_ctx = cl.create_some_context() + discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=3) + + from pytools.obj_array import join_fields + results = join_fields(results[0], results[1], *results[2], *results[3]) + bound_op = bind(discr, results) + + print(bound_op.eval_code) + + +if __name__ == "__main__": + main() diff --git a/grudge/execution.py b/grudge/execution.py index c20aa4bc6271f1b9a8f5620c8b3021c22563880c..89db0a1aaae757d8fb9f54089328757fa4a9897f 100644 --- a/grudge/execution.py +++ b/grudge/execution.py @@ -31,6 +31,7 @@ from pytools import memoize_in import grudge.symbolic.mappers as mappers from grudge import sym +from grudge.function_registry import base_function_registry import logging logger = logging.getLogger(__name__) @@ -50,6 +51,7 @@ class ExecutionMapper(mappers.Evaluator, super(ExecutionMapper, self).__init__(context) self.discrwb = bound_op.discrwb self.bound_op = bound_op + self.function_registry = bound_op.function_registry self.queue = queue # {{{ expression mappings ------------------------------------------------- @@ -95,45 +97,8 @@ class ExecutionMapper(mappers.Evaluator, return value def map_call(self, expr): - from pymbolic.primitives import Variable - assert isinstance(expr.function, Variable) - - # FIXME: Make a way to register functions - args = [self.rec(p) for p in expr.parameters] - from numbers import Number - representative_arg = args[0] - if ( - isinstance(representative_arg, Number) - or (isinstance(representative_arg, np.ndarray) - and representative_arg.shape == ())): - func = getattr(np, expr.function.name) - return func(*args) - - cached_name = "map_call_knl_" - - i = Variable("i") - func = Variable(expr.function.name) - if expr.function.name == "fabs": # FIXME - func = Variable("abs") - cached_name += "abs" - else: - cached_name += expr.function.name - - @memoize_in(self.bound_op, cached_name) - def knl(): - knl = lp.make_kernel( - "{[i]: 0<=i