def topological_sort(stmts, root_deps):
    """Order statements so that each one follows everything it depends on.

    :arg stmts: iterable of statement objects, each carrying an ``id`` and
        an iterable ``depends_on`` of statement ids.
    :arg root_deps: iterable of statement ids to schedule. Only these
        statements and their transitive dependencies end up in the result.
    :returns: a list of statements in depth-first post-order, i.e. every
        dependency precedes its dependents and no statement appears twice.
    :raises KeyError: if an id is referenced that no statement in *stmts*
        carries.
    :raises ValueError: if the dependencies reachable from *root_deps*
        form a cycle.
    """
    id_to_stmt = {stmt.id: stmt for stmt in stmts}

    ordered_stmts = []
    satisfied = set()
    # Ids whose dependencies are being resolved right now, outermost first.
    # Running into one of them again means the graph loops back on itself;
    # without this check the recursion below would never terminate.
    in_progress = []

    def satisfy_dep(name):
        if name in satisfied:
            return

        if name in in_progress:
            cycle = in_progress[in_progress.index(name):] + [name]
            raise ValueError(
                    "dependency cycle detected: %s"
                    % " -> ".join(str(n) for n in cycle))

        stmt = id_to_stmt[name]

        in_progress.append(name)
        for dep in stmt.depends_on:
            satisfy_dep(dep)
        in_progress.pop()

        ordered_stmts.append(stmt)
        satisfied.add(name)

    for d in root_deps:
        satisfy_dep(d)

    return ordered_stmts
": sym.var("input_dt", sym.DD_SCALAR), + "w": sym.make_sym_array("input_w", 3), + "

residual_w": sym.var("input_residual") + } + output_vars = [v for v in ctx.keys()] + yielded_states = [] + + from dagrt.codegen.transform import isolate_function_calls_in_phase + ordered_stmts = topological_sort( + isolate_function_calls_in_phase( + phase, + dag.get_stmt_id_generator(), + dag.get_var_name_generator()).statements, + phase.depends_on) + + for stmt in ordered_stmts: + if stmt.condition is not True: + raise NotImplementedError( + "non-True condition (in statement '%s') not supported" + % stmt.id) + + if isinstance(stmt, lang.Nop): + pass + elif isinstance(stmt, lang.AssignExpression): + if not isinstance(stmt.lhs, p.Variable): + raise NotImplementedError("lhs of statement %s is not a variable: %s" + % (stmt.id, stmt.lhs)) + ctx[stmt.lhs.name] = sym.cse( + DagrtToGrudgeRewriter(ctx)(stmt.rhs), + ( + stmt.lhs.name + .replace("<", "") + .replace(">", ""))) + elif isinstance(stmt, lang.AssignFunctionCall): + if stmt.function_id != rhs_name: + raise NotImplementedError( + "statement '%s' calls unsupported function '%s'" + % (stmt.id, stmt.function_id)) + if stmt.parameters: + raise NotImplementedError( + "statement '%s' calls function '%s' with positional arguments" + % (stmt.id, stmt.function_id)) + kwargs = {name: sym.cse(DagrtToGrudgeRewriter(ctx)(arg)) + for name, arg in stmt.kw_parameters.items()} + + if len(stmt.assignees) != 1: + raise NotImplementedError( + "statement '%s' calls function '%s' " + "with more than one LHS" + % (stmt.id, stmt.function_id)) + + assignee, = stmt.assignees + ctx[assignee] = GrudgeArgSubstitutor( + kwargs)(sym_operator) + elif isinstance(stmt, lang.YieldState): + d2g = DagrtToGrudgeRewriter(ctx) + yielded_states.append( + (stmt.time_id, d2g(stmt.time), stmt.component_id, + d2g(stmt.expression))) + else: + raise NotImplementedError("statement %s is of unsupported type ''%s'" + % (stmt.id, type(stmt).__name__)) + + return output_vars, [ctx[ov] for ov in output_vars], yielded_states + + +def main(): + dt_method = 
LSRK4Method(component_id="w") + dt_code = dt_method.generate() + + from meshmode.mesh import BTAG_ALL, BTAG_NONE + op = StrongWaveOperator(-0.1, 2, + dirichlet_tag=BTAG_NONE, + neumann_tag=BTAG_NONE, + radiation_tag=BTAG_ALL, + flux_type="upwind") + sym_op = op.sym_operator() + + ov, results, ys = transcribe_phase(dt_code, "primary", "w", sym_op) + for ov_i, res_i in zip(ov, results): + print(75*"#") + print(sym.pretty(res_i)) + print(f"{ov_i} IS THE ABOVE ^") + + from meshmode.mesh.generation import generate_regular_rect_mesh + mesh = generate_regular_rect_mesh( + a=(-0.5,)*2, + b=(0.5,)*2, + n=(16,)*2) + + cl_ctx = cl.create_some_context() + discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=3) + + from pytools.obj_array import join_fields + results = join_fields(results[0], results[1], *results[2], *results[3]) + bound_op = bind(discr, results) + + print(bound_op.eval_code) + + +if __name__ == "__main__": + main() diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py new file mode 100644 index 0000000000000000000000000000000000000000..efe86ec9178e7ea6b70c177a23c6eceb94a7315f --- /dev/null +++ b/examples/dagrt_fusion/fusion-study.py @@ -0,0 +1,355 @@ +from __future__ import division, print_function + +__copyright__ = "Copyright (C) 2015 Andreas Kloeckner" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
# {{{ topological sort

def topological_sort(stmts, root_deps):
    """Order statements so that each one follows everything it depends on.

    :arg stmts: iterable of statement objects, each carrying an ``id`` and
        an iterable ``depends_on`` of statement ids.
    :arg root_deps: iterable of statement ids to schedule. Only these
        statements and their transitive dependencies end up in the result.
    :returns: a list of statements in depth-first post-order, i.e. every
        dependency precedes its dependents and no statement appears twice.
    :raises KeyError: if an id is referenced that no statement in *stmts*
        carries.
    :raises ValueError: if the dependencies reachable from *root_deps*
        form a cycle.
    """
    id_to_stmt = {stmt.id: stmt for stmt in stmts}

    ordered_stmts = []
    satisfied = set()
    # Ids whose dependencies are being resolved right now, outermost first.
    # Running into one of them again means the graph loops back on itself;
    # without this check the recursion below would never terminate.
    in_progress = []

    def satisfy_dep(name):
        if name in satisfied:
            return

        if name in in_progress:
            cycle = in_progress[in_progress.index(name):] + [name]
            raise ValueError(
                    "dependency cycle detected: %s"
                    % " -> ".join(str(n) for n in cycle))

        stmt = id_to_stmt[name]

        in_progress.append(name)
        for dep in stmt.depends_on:
            satisfy_dep(dep)
        in_progress.pop()

        ordered_stmts.append(stmt)
        satisfied.add(name)

    for d in root_deps:
        satisfy_dep(d)

    return ordered_stmts

# }}}
class DagrtToGrudgeRewriter(PymbolicEvaluationMapper):
    """Rewrite a dagrt expression in terms of grudge expressions.

    Built on an evaluation mapper (rather than an identity mapper) so that
    symbolic vectors propagate to the outermost level of the result.

    :arg context: mapping from variable names to the expressions that
        replace them.
    """

    def __init__(self, context):
        # Go through the base initializer instead of assigning
        # ``self.context`` by hand, so that any further state the inherited
        # ``map_*`` methods rely on is set up as well.
        super().__init__(context)

    def map_variable(self, expr):
        """Return the context entry stored under the variable's name."""
        return self.context[expr.name]

    def map_call(self, expr):
        """Reject function calls; they are not expected in these expressions."""
        raise ValueError("function call not expected")


class GrudgeArgSubstitutor(gmap.SymbolicEvaluator):
    """Substitute named grudge variables in a symbolic operator.

    :arg args: mapping from variable names to the expressions that replace
        them.
    """

    def __init__(self, args):
        # Initialize the base mapper (with an empty context). Previously
        # this was skipped, so the fallback in :meth:`map_grudge_variable`
        # touched attributes that had never been created.
        super().__init__()
        self.args = args

    def map_grudge_variable(self, expr):
        """Return the replacement for *expr*, if one was supplied.

        Variables without an entry in ``args`` are handed to the inherited
        ``map_variable``.
        NOTE(review): with the empty base context that presumably reports
        an unknown variable -- confirm this is the intended behavior rather
        than leaving such variables untouched.
        """
        if expr.name in self.args:
            return self.args[expr.name]

        return super().map_variable(expr)

": sym.var("input_dt", sym.DD_SCALAR), + f"{field_var_name}": sym.make_sym_array( + f"input_{field_var_name}", field_components), + f"

residual": sym.make_sym_array("input_residual", field_components), + } + + rhs_name = f"{field_var_name}" + output_vars = [v for v in ctx] + yielded_states = [] + + from dagrt.codegen.transform import isolate_function_calls_in_phase + ordered_stmts = topological_sort( + isolate_function_calls_in_phase( + phase, + dag.get_stmt_id_generator(), + dag.get_var_name_generator()).statements, + phase.depends_on) + + for stmt in ordered_stmts: + if stmt.condition is not True: + raise NotImplementedError( + "non-True condition (in statement '%s') not supported" + % stmt.id) + + if isinstance(stmt, lang.Nop): + pass + elif isinstance(stmt, lang.AssignExpression): + if not isinstance(stmt.lhs, p.Variable): + raise NotImplementedError("lhs of statement %s is not a variable: %s" + % (stmt.id, stmt.lhs)) + ctx[stmt.lhs.name] = sym.cse( + DagrtToGrudgeRewriter(ctx)(stmt.rhs), + ( + stmt.lhs.name + .replace("<", "") + .replace(">", ""))) + elif isinstance(stmt, lang.AssignFunctionCall): + if stmt.function_id != rhs_name: + raise NotImplementedError( + "statement '%s' calls unsupported function '%s'" + % (stmt.id, stmt.function_id)) + if stmt.parameters: + raise NotImplementedError( + "statement '%s' calls function '%s' with positional arguments" + % (stmt.id, stmt.function_id)) + kwargs = {name: sym.cse(DagrtToGrudgeRewriter(ctx)(arg)) + for name, arg in stmt.kw_parameters.items()} + + if len(stmt.assignees) != 1: + raise NotImplementedError( + "statement '%s' calls function '%s' " + "with more than one LHS" + % (stmt.id, stmt.function_id)) + + assignee, = stmt.assignees + ctx[assignee] = GrudgeArgSubstitutor( + kwargs)(sym_operator) + elif isinstance(stmt, lang.YieldState): + d2g = DagrtToGrudgeRewriter(ctx) + yielded_states.append( + (stmt.time_id, d2g(stmt.time), stmt.component_id, + d2g(stmt.expression))) + else: + raise NotImplementedError("statement %s is of unsupported type ''%s'" + % (stmt.id, type(stmt).__name__)) + + return output_vars, [ctx[ov] for ov in output_vars], 
def get_strong_wave_op_with_discr(cl_ctx, dims=3, order=4):
    """Build a strong-form wave operator together with its discretization.

    :arg cl_ctx: OpenCL context handed to the discretization.
    :arg dims: number of spatial dimensions of the mesh.
    :arg order: order passed to the discretization.
    :returns: a tuple ``(op, discr)``.
    """
    from meshmode.mesh.generation import generate_regular_rect_mesh

    # Regular mesh on the box [-0.5, 0.5]^dims with n=16 along every axis.
    lower_corner = (-0.5,)*dims
    upper_corner = (0.5,)*dims
    mesh = generate_regular_rect_mesh(
            a=lower_corner, b=upper_corner, n=(16,)*dims)

    logger.info("%d elements" % mesh.nelements)

    discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=order)

    # Source term: sin(omega*t) * exp(-|x - center|^2 / width^2).
    omega = 3
    width = 0.05
    center = np.array([0.1, 0.22, 0.33])[:dims]

    dist = sym.nodes(mesh.dim) - center
    time_var = sym.ScalarVariable("t")
    source = (
            sym.sin(omega*time_var)
            * sym.exp(-np.dot(dist, dist) / width**2))

    from grudge.models.wave import StrongWaveOperator
    from meshmode.mesh import BTAG_ALL, BTAG_NONE

    boundary_tags = dict(
            dirichlet_tag=BTAG_NONE,
            neumann_tag=BTAG_NONE,
            radiation_tag=BTAG_ALL)
    op = StrongWaveOperator(
            -0.1, dims,
            source_f=source,
            flux_type="upwind",
            **boundary_tags)

    op.check_bc_coverage(mesh)

    return (op, discr)


class RK4TimeStepper(object):
    """Time stepper that wraps the stepper returned by ``set_up_rk4``."""

    def __init__(self, field_var_name, dt, fields, rhs, component_getter,
            t_start=0):
        self.stepper = set_up_rk4(field_var_name, dt, fields, rhs, t_start)
        self.field_var_name = field_var_name
        self.component_getter = component_getter

    def run(self, t_end):
        """Yield ``(t, components)`` for every computed state up to *t_end*.

        Events that are not ``StateComputed`` instances are ignored.
        """
        for event in self.stepper.run(t_end=t_end):
            if not isinstance(event, self.stepper.StateComputed):
                continue

            assert event.component_id == self.field_var_name
            yield (event.t, self.component_getter(event.state_component))


class FusedRK4TimeStepper(object):
    """Time stepper that evaluates one whole LSRK4 step as one bound operator.

    The time integrator's "primary" phase is transcribed into a single
    symbolic expression around *sym_rhs* and bound to *discr*; each step is
    then one call of that bound operator.
    """

    def __init__(self, queue, field_var_name, dt, fields, sym_rhs, discr,
            component_getter, t_start=0):
        self.t_start = t_start
        self.dt = dt

        dt_code = LSRK4Method(component_id=field_var_name).generate()

        from pytools.obj_array import join_fields

        # Flatten: list entries contribute their items, anything else is
        # taken as a single field.
        flat = []
        for entry in fields:
            flat += entry if isinstance(entry, list) else [entry]
        flattened_fields = join_fields(*flat)
        del fields

        _, results, _ = transcribe_phase(
                dt_code, field_var_name, len(flattened_fields),
                "primary", sym_rhs)

        # Result order follows the context built in transcribe_phase:
        # time, time step, state, residual.
        output_t = results[0]
        output_dt = results[1]
        output_states = results[2]
        output_residuals = results[3]

        assert len(output_states) == len(flattened_fields)
        assert len(output_states) == len(output_residuals)

        self.bound_op = bind(
                discr, join_fields(output_t, output_dt, *output_states))
        self.queue = queue

        self.state_name = f"input_{field_var_name}"

        self.initial_context = {
                "input_t": t_start,
                "input_dt": dt,
                self.state_name: flattened_fields,
                "input_residual": flattened_fields,
        }

        self.component_getter = component_getter

    def run(self, t_end):
        """Yield ``(t, components)`` after each step while ``t <= t_end``.

        Works on a copy of the initial context, so the stepper can be run
        again from the start.
        """
        step_inputs = dict(self.initial_context)
        now = self.t_start

        while now <= t_end:
            outputs = self.bound_op(self.queue, **step_inputs)

            # Outputs are laid out as (t, dt, *state); feed them back in as
            # the inputs of the next step.
            now = outputs[0]
            new_state = outputs[2:]
            step_inputs["input_t"] = now
            step_inputs["input_dt"] = outputs[1]
            step_inputs[self.state_name] = new_state

            yield (now, self.component_getter(new_state))


def get_strong_wave_component(state_component):
    """Split a state into ``(first entry, remaining entries)``."""
    head = state_component[0]
    tail = state_component[1:]
    return (head, tail)
# {{{ equivalence check

def test_stepper_equivalence(order=4):
    """Time the stock and the fused RK4 stepper on a 2D wave problem.

    Both steppers are run up to ``t_end=1``; the step index is printed for
    every yielded state, followed by the elapsed wall time.

    NOTE(review): despite the name, nothing is compared at the moment; the
    actual equivalence check is disabled (see the end of the function).

    :arg order: order passed on to the discretization.
    :raises ValueError: if no time step is known for the chosen number of
        dimensions.
    """
    import time

    cl_ctx = cl.create_some_context()
    queue = cl.CommandQueue(cl_ctx)

    dims = 2

    op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=order)

    # Time step per number of dimensions. Looked up explicitly so that an
    # unsupported `dims` fails here instead of leaving `dt` unbound.
    dt_for_dims = {2: 0.04, 3: 0.02}
    if dims not in dt_for_dims:
        raise ValueError("no time step known for dims=%d" % dims)
    dt = dt_for_dims[dims]

    from pytools.obj_array import join_fields
    fields = join_fields(discr.zeros(queue),
            [discr.zeros(queue) for _ in range(discr.dim)])

    bound_op = bind(discr, op.sym_operator())

    def rhs(t, w):
        return bound_op(queue, t=t, w=w)

    stepper = RK4TimeStepper(
            "w", dt, fields, rhs, get_strong_wave_component)

    fused_stepper = FusedRK4TimeStepper(
            queue, "w", dt, fields, op.sym_operator(),
            discr, get_strong_wave_component)

    # NOTE: nsteps is reported for final_t, while the timing runs below
    # only go up to t_end=1.
    final_t = 10
    nsteps = int(final_t/dt)
    print("dt=%g nsteps=%d" % (dt, nsteps))

    s = time.time()
    for i, _ in enumerate(stepper.run(t_end=1)):
        print(i)
    print("stepper", time.time() - s)

    s = time.time()
    for i, _ in enumerate(fused_stepper.run(t_end=1)):
        print(i)
    print("fused stepper", time.time() - s)

    # Disabled equivalence check, kept for reference.
    # NOTE(review): as written, both iterators come from
    # fused_stepper.run(), so the fused stepper would be compared with
    # itself; the loop presumably should iterate stepper.run() -- confirm
    # before re-enabling.
    #
    # step = 0
    # norm = bind(discr, sym.norm(2, sym.var("u_ref") - sym.var("u")))
    #
    # fused_steps = fused_stepper.run(t_end=final_t)
    #
    # for t_ref, (u_ref, v_ref) in fused_stepper.run(t_end=final_t):
    #     step += 1
    #     t, (u, v) = next(fused_steps)
    #     assert t == t_ref, step
    #     assert norm(queue, u=u, u_ref=u_ref) <= 1e-13, step

# }}}


if __name__ == "__main__":
    test_stepper_equivalence()
def should_track_value(val):
    """Return whether *val* is an array that counts towards memory traffic.

    Only :class:`numpy.ndarray` and ``pyopencl.array.Array`` instances are
    tracked.
    """
    return isinstance(val, (np.ndarray, pyopencl.array.Array))


def get_size(val):
    """Return the number of bytes held by *val*.

    An object-dtype :class:`numpy.ndarray` is treated as a container: its
    size is the sum of the ``nbytes`` of the entries obtained by iterating
    over it. Any other value reports its own ``nbytes``.
    """
    # ``np.object`` was merely an alias of the builtin ``object`` and is
    # gone from current NumPy releases; comparing against the builtin works
    # with every NumPy version.
    if isinstance(val, np.ndarray) and val.dtype == object:
        return sum(item.nbytes for item in val)

    return val.nbytes


def track_memory_traffic(f):
    """Decorate an instruction-execution method with memory traffic counting.

    The wrapped method is called as ``f(self, insn)`` and must return a
    tuple whose first entry is an iterable of ``(name, value)`` assignments.
    While ``self._track_memory_traffic`` is false, its result is returned
    unchanged. Otherwise a :class:`MemoryTrafficResult` is appended to the
    returned tuple, holding the bytes read (as reported by
    ``self.track_memory_reads_for_insn(insn)``) and the bytes written
    (total size of the assigned values that are tracked arrays).
    """
    import functools

    @functools.wraps(f)
    def run_with_memory_tracking(self, insn):
        if not self._track_memory_traffic:
            return f(self, insn)

        # Reads are counted before the instruction executes.
        bytes_read = self.track_memory_reads_for_insn(insn)

        result = f(self, insn)

        bytes_written = sum(
                get_size(value)
                for _, value in result[0]
                if should_track_value(value))

        return result + (MemoryTrafficResult(bytes_read, bytes_written),)

    return run_with_memory_tracking
@track_memory_traffic def map_insn_loopy_kernel(self, insn): kwargs = {} kdescr = insn.kernel_descriptor @@ -360,10 +415,12 @@ class ExecutionMapper(mappers.Evaluator, evt, result_dict = kdescr.loopy_kernel(self.queue, **kwargs) return list(result_dict.items()), [] + @track_memory_traffic def map_insn_assign(self, insn): return [(name, self.rec(expr)) for name, expr in zip(insn.names, insn.exprs)], [] + @track_memory_traffic def map_insn_assign_to_discr_scoped(self, insn): assignments = [] for name, expr in zip(insn.names, insn.exprs): @@ -373,10 +430,12 @@ class ExecutionMapper(mappers.Evaluator, return assignments, [] + @track_memory_traffic def map_insn_assign_from_discr_scoped(self, insn): return [(insn.name, self.discrwb._discr_scoped_subexpr_name_to_value[insn.name])], [] + @track_memory_traffic def map_insn_diff_batch_assign(self, insn): field = self.rec(insn.field) repr_op = insn.operators[0] @@ -493,7 +552,8 @@ class BoundOperator(object): + sep + str(self.eval_code)) - def __call__(self, queue, profile_data=None, log_quantities=None, **context): + def __call__(self, queue, profile_data=None, log_quantities=None, + _track_memory_traffic=False, **context): import pyopencl.array as cl_array def replace_queue(a): @@ -522,9 +582,11 @@ class BoundOperator(object): new_context[name] = with_object_array_or_scalar(replace_queue, var) return self.eval_code.execute( - ExecutionMapper(queue, new_context, self), + ExecutionMapper(queue, new_context, self, + _track_memory_traffic=_track_memory_traffic), profile_data=profile_data, - log_quantities=log_quantities) + log_quantities=log_quantities, + _track_memory_traffic=_track_memory_traffic) # }}} diff --git a/grudge/shortcuts.py b/grudge/shortcuts.py index bb12689bf65208c3a48bb299f2a44c6651c5b58b..835dd115d649b3d39d868da0354f2eb44d121464 100644 --- a/grudge/shortcuts.py +++ b/grudge/shortcuts.py @@ -27,6 +27,133 @@ THE SOFTWARE. 
def topological_sort(stmts, root_deps):
    """Order statements so that each one follows everything it depends on.

    :arg stmts: iterable of statement objects, each carrying an ``id`` and
        an iterable ``depends_on`` of statement ids.
    :arg root_deps: iterable of statement ids to schedule. Only these
        statements and their transitive dependencies end up in the result.
    :returns: a list of statements in depth-first post-order, i.e. every
        dependency precedes its dependents and no statement appears twice.
    :raises KeyError: if an id is referenced that no statement in *stmts*
        carries.
    :raises ValueError: if the dependencies reachable from *root_deps*
        form a cycle.
    """
    id_to_stmt = {stmt.id: stmt for stmt in stmts}

    ordered_stmts = []
    satisfied = set()
    # Ids whose dependencies are being resolved right now, outermost first.
    # Running into one of them again means the graph loops back on itself;
    # without this check the recursion below would never terminate.
    in_progress = []

    def satisfy_dep(name):
        if name in satisfied:
            return

        if name in in_progress:
            cycle = in_progress[in_progress.index(name):] + [name]
            raise ValueError(
                    "dependency cycle detected: %s"
                    % " -> ".join(str(n) for n in cycle))

        stmt = id_to_stmt[name]

        in_progress.append(name)
        for dep in stmt.depends_on:
            satisfy_dep(dep)
        in_progress.pop()

        ordered_stmts.append(stmt)
        satisfied.add(name)

    for d in root_deps:
        satisfy_dep(d)

    return ordered_stmts


# Use evaluation, not identity mappers to propagate symbolic vectors to
# outermost level.

class DagrtToGrudgeRewriter(PymbolicEvaluationMapper):
    """Rewrite a dagrt expression in terms of grudge expressions.

    :arg context: mapping from variable names to the expressions that
        replace them.
    """

    def __init__(self, context):
        # Go through the base initializer instead of assigning
        # ``self.context`` by hand, so that any further state the inherited
        # ``map_*`` methods rely on is set up as well.
        super().__init__(context)

    def map_variable(self, expr):
        """Return the context entry stored under the variable's name."""
        return self.context[expr.name]

    def map_call(self, expr):
        """Reject function calls; they are not expected in these expressions."""
        raise ValueError("function call not expected")


class GrudgeArgSubstitutor(gmap.SymbolicEvaluator):
    """Substitute named grudge variables in a symbolic operator.

    :arg args: mapping from variable names to the expressions that replace
        them.
    """

    def __init__(self, args):
        # Initialize the base mapper (with an empty context). Previously
        # this was skipped, so the fallback in :meth:`map_grudge_variable`
        # touched attributes that had never been created.
        super().__init__()
        self.args = args

    def map_grudge_variable(self, expr):
        """Return the replacement for *expr*, if one was supplied.

        Variables without an entry in ``args`` are handed to the inherited
        ``map_variable``.
        NOTE(review): with the empty base context that presumably reports
        an unknown variable -- confirm this is the intended behavior rather
        than leaving such variables untouched.
        """
        if expr.name in self.args:
            return self.args[expr.name]

        return super().map_variable(expr)

": sym.var("input_dt", sym.DD_SCALAR), + "w": sym.make_sym_array("input_w", 3), + "

residual_w": sym.var("input_residual") + } + output_vars = [v for v in ctx.keys()] + yielded_states = [] + + from dagrt.codegen.transform import isolate_function_calls_in_phase + ordered_stmts = topological_sort( + isolate_function_calls_in_phase( + phase, + dag.get_stmt_id_generator(), + dag.get_var_name_generator()).statements, + phase.depends_on) + + for stmt in ordered_stmts: + if stmt.condition is not True: + raise NotImplementedError( + "non-True condition (in statement '%s') not supported" + % stmt.id) + + if isinstance(stmt, lang.Nop): + pass + elif isinstance(stmt, lang.AssignExpression): + if not isinstance(stmt.lhs, p.Variable): + raise NotImplementedError("lhs of statement %s is not a variable: %s" + % (stmt.id, stmt.lhs)) + ctx[stmt.lhs.name] = sym.cse( + DagrtToGrudgeRewriter(ctx)(stmt.rhs), + ( + stmt.lhs.name + .replace("<", "") + .replace(">", ""))) + elif isinstance(stmt, lang.AssignFunctionCall): + if stmt.function_id != rhs_name: + raise NotImplementedError( + "statement '%s' calls unsupported function '%s'" + % (stmt.id, stmt.function_id)) + if stmt.parameters: + raise NotImplementedError( + "statement '%s' calls function '%s' with positional arguments" + % (stmt.id, stmt.function_id)) + kwargs = {name: sym.cse(DagrtToGrudgeRewriter(ctx)(arg)) + for name, arg in stmt.kw_parameters.items()} + + if len(stmt.assignees) != 1: + raise NotImplementedError( + "statement '%s' calls function '%s' " + "with more than one LHS" + % (stmt.id, stmt.function_id)) + + assignee, = stmt.assignees + ctx[assignee] = GrudgeArgSubstitutor( + kwargs)(sym_operator) + elif isinstance(stmt, lang.YieldState): + d2g = DagrtToGrudgeRewriter(ctx) + yielded_states.append( + (stmt.time_id, d2g(stmt.time), stmt.component_id, + d2g(stmt.expression))) + else: + raise NotImplementedError("statement %s is of unsupported type ''%s'" + % (stmt.id, type(stmt).__name__)) + + return output_vars, [ctx[ov] for ov in output_vars], yielded_states + def set_up_rk4(field_var_name, dt, 
fields, rhs, t_start=0): from leap.rk import LSRK4Method @@ -52,7 +179,6 @@ def make_visualizer(discrwb, vis_order): def make_boundary_visualizer(discrwb, vis_order): from meshmode.discretization.visualization import make_visualizer - from grudge import sym with cl.CommandQueue(discrwb.cl_context) as queue: return make_visualizer( queue, discrwb.discr_from_dd(sym.BTAG_ALL), diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py index b27d55855237b5e0200b094c2aa164e192022f0b..919f6b6067dd76e1e81800587db504323ef79150 100644 --- a/grudge/symbolic/compiler.py +++ b/grudge/symbolic/compiler.py @@ -192,7 +192,6 @@ class Assign(AssignBase): from pymbolic.primitives import Variable deps -= set(Variable(name) for name in self.names) - if not each_vector: self._dependencies = deps @@ -447,7 +446,8 @@ class Code(object): available_insns = [ (insn, insn.priority) for insn in self.instructions if insn not in done_insns - and all(dep.name in available_names + and all((dep.aggregate.name if isinstance(dep, Subscript) + else dep.name) in available_names for dep in insn.get_dependencies())] if not available_insns: @@ -455,7 +455,8 @@ class Code(object): from pytools import flatten discardable_vars = set(available_names) - set(flatten( - [dep.name for dep in insn.get_dependencies()] + [dep.aggregate.name if isinstance(dep, Subscript) else dep.name + for dep in insn.get_dependencies()] for insn in self.instructions if insn not in done_insns)) @@ -482,7 +483,7 @@ class Code(object): return argmax2(available_insns), discardable_vars def execute(self, exec_mapper, pre_assign_check=None, profile_data=None, - log_quantities=None): + log_quantities=None, _track_memory_traffic=False): if profile_data is not None: from time import time start_time = time() @@ -493,6 +494,8 @@ class Code(object): profile_data['total_time'] = 0 if log_quantities is not None: exec_sub_timer = log_quantities["exec_timer"].start_sub_timer() + if _track_memory_traffic is not None: + memory_traffic = 
class SymbolicEvaluator(pymbolic.mapper.evaluator.EvaluationMapper):
    """Evaluation mapper that rebuilds grudge-specific nodes symbolically.

    Operator bindings, calls and common subexpressions are reconstructed
    around their recursively processed children; node coordinate components
    are returned as they are. Extra positional and keyword arguments are
    forwarded unchanged to every recursive call.
    """

    def map_operator_binding(self, expr, *args, **kwargs):
        """Call the bound operator on the processed field."""
        return expr.op(self.rec(expr.field, *args, **kwargs))

    def map_node_coordinate_component(self, expr, *args, **kwargs):
        """Return node coordinate components untouched."""
        return expr

    def map_call(self, expr, *args, **kwargs):
        """Rebuild the call with processed parameters; the function is kept."""
        return type(expr)(
                expr.function,
                tuple(self.rec(child, *args, **kwargs)
                    for child in expr.parameters))

    def map_call_with_kwargs(self, expr, *args, **kwargs):
        """Like :meth:`map_call`, also processing the keyword parameters."""
        return type(expr)(
                expr.function,
                tuple(self.rec(child, *args, **kwargs)
                    for child in expr.parameters),
                {key: self.rec(val, *args, **kwargs)
                    for key, val in expr.kw_parameters.items()})

    def map_common_subexpression(self, expr, *args, **kwargs):
        """Rebuild the subexpression around its processed child."""
        # Accept and forward extra traversal arguments like every other
        # method of this mapper; with the bare ``(self, expr)`` signature a
        # traversal carrying extra arguments failed on the first common
        # subexpression it met.
        return type(expr)(
                self.rec(expr.child, *args, **kwargs),
                expr.prefix, expr.scope)