diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py index f130c2486cdc271f5935da2aa79a8072b9b45644..3d30075742e7d6442f21d1e78adf4f6550b93bba 100644 --- a/examples/dagrt_fusion/fusion-study.py +++ b/examples/dagrt_fusion/fusion-study.py @@ -38,18 +38,19 @@ import pyopencl.array # noqa import pytest import dagrt.language as lang -import loopy as lp import pymbolic.primitives as p import grudge.symbolic.mappers as gmap import grudge.symbolic.operators as op from grudge.execution import ExecutionMapper from pymbolic.mapper.evaluator import EvaluationMapper \ as PymbolicEvaluationMapper -from pytools import memoize_in from grudge import sym, bind, DGDiscretizationWithBoundaries from leap.rk import LSRK4Method +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) + logging.basicConfig(level=logging.INFO) @@ -71,7 +72,7 @@ def topological_sort(stmts, root_deps): return stmt = id_to_stmt[name] - for dep in stmt.depends_on: + for dep in sorted(stmt.depends_on): satisfy_dep(dep) ordered_stmts.append(stmt) satisfied.add(name) @@ -333,7 +334,8 @@ class RK4TimeStepper(RK4TimeStepperBase): self.queue = queue self.grudge_bound_op = grudge_bound_op - self.set_up_stepper(discr, field_var_name, sym_rhs, num_fields, exec_mapper_factory) + self.set_up_stepper( + discr, field_var_name, sym_rhs, num_fields, exec_mapper_factory) self.component_getter = component_getter def _bound_op(self, t, *args, profile_data=None): @@ -411,8 +413,8 @@ def get_strong_wave_component(state_component): # {{{ equivalence check between fused and non-fused versions -def test_stepper_equivalence(order=4): - cl_ctx = cl.create_some_context() +def test_stepper_equivalence(ctx_factory, order=4): + cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) dims = 2 @@ -479,7 +481,8 @@ class MemOpCountingExecutionMapper(ExecutionMapper): args = [self.rec(p) for p in expr.parameters] return self.context[expr.function.name](*args, 
profile_data=profile_data) - def map_profiled_essentially_elementwise_linear(self, op, field_expr, profile_data): + def map_profiled_essentially_elementwise_linear(self, op, field_expr, + profile_data): result = getattr(self, op.mapper_method)(op, field_expr) if profile_data is not None: @@ -541,6 +544,9 @@ class MemOpCountingExecutionMapper(ExecutionMapper): if profile_data is not None and isinstance(val, pyopencl.array.Array): profile_data["bytes_read"] = ( profile_data.get("bytes_read", 0) + val.nbytes) + profile_data["bytes_read_within_assignments"] = ( + profile_data.get("bytes_read_within_assignments", 0) + + val.nbytes) discr = self.discrwb.discr_from_dd(kdescr.governing_dd) for name in kdescr.scalar_args(): @@ -563,6 +569,9 @@ class MemOpCountingExecutionMapper(ExecutionMapper): if profile_data is not None and isinstance(val, pyopencl.array.Array): profile_data["bytes_written"] = ( profile_data.get("bytes_written", 0) + val.nbytes) + profile_data["bytes_written_within_assignments"] = ( + profile_data.get("bytes_written_within_assignments", 0) + + val.nbytes) return list(result_dict.items()), [] @@ -598,7 +607,7 @@ class MemOpCountingExecutionMapper(ExecutionMapper): for _, value in assignments: profile_data["bytes_written"] = ( - profile_data.get("bytes_writte", 0) + value.nbytes) + profile_data.get("bytes_written", 0) + value.nbytes) return assignments, futures @@ -610,12 +619,13 @@ class MemOpCountingExecutionMapper(ExecutionMapper): # {{{ mem op counter check @pytest.mark.parametrize("use_fusion", (True, False)) -def test_stepper_mem_ops(use_fusion): - cl_ctx = cl.create_some_context() +def test_stepper_mem_ops(ctx_factory, use_fusion): + cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) dims = 2 - op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=3) + + op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=3) t_start = 0 dt = 0.04 @@ -625,11 +635,11 @@ def test_stepper_mem_ops(use_fusion): ic = 
join_fields(discr.zeros(queue), [discr.zeros(queue) for i in range(discr.dim)]) - bound_op = bind( - discr, op.sym_operator(), - exec_mapper_factory=MemOpCountingExecutionMapper) - if not use_fusion: + bound_op = bind( + discr, op.sym_operator(), + exec_mapper_factory=MemOpCountingExecutionMapper) + stepper = RK4TimeStepper( queue, discr, "w", bound_op, 1 + discr.dim, get_strong_wave_component, @@ -643,20 +653,151 @@ def test_stepper_mem_ops(use_fusion): step = 0 - norm = bind(discr, sym.norm(2, sym.var("u_ref") - sym.var("u"))) - nsteps = int(np.ceil((t_end + 1e-9) / dt)) for (_, _, profile_data) in stepper.run( ic, t_start, dt, t_end, return_profile_data=True): step += 1 - logger.debug("step %d/%d", step, nsteps) + logger.info("step %d/%d", step, nsteps) + logger.info("fusion? %s", use_fusion) logger.info("bytes read: %d", profile_data["bytes_read"]) logger.info("bytes written: %d", profile_data["bytes_written"]) + logger.info("bytes total: %d", + profile_data["bytes_read"] + profile_data["bytes_written"]) + +# }}} + + +# {{{ paper outputs + +def get_example_stepper(queue, dims=2, order=3, use_fusion=True, + exec_mapper_factory=ExecutionMapper, + return_ic=False): + op, discr = get_strong_wave_op_with_discr(queue.context, dims=dims, order=order) + + if not use_fusion: + bound_op = bind( + discr, op.sym_operator(), + exec_mapper_factory=exec_mapper_factory) + + stepper = RK4TimeStepper( + queue, discr, "w", bound_op, 1 + discr.dim, + get_strong_wave_component, + exec_mapper_factory=exec_mapper_factory) + + else: + stepper = FusedRK4TimeStepper( + queue, discr, "w", op.sym_operator(), 1 + discr.dim, + get_strong_wave_component, + exec_mapper_factory=exec_mapper_factory) + + if return_ic: + from pytools.obj_array import join_fields + ic = join_fields(discr.zeros(queue), + [discr.zeros(queue) for i in range(discr.dim)]) + return stepper, ic + + return stepper + + +def statement_counts_table(): + cl_ctx = cl.create_some_context() + queue = cl.CommandQueue(cl_ctx) +
fused_stepper = get_example_stepper(queue, use_fusion=True) + stepper = get_example_stepper(queue, use_fusion=False) + + print(r"\begin{tabular}{lr}") + print(r"\toprule") + print(r"Operator & Grudge Node Count \\") + print(r"\midrule") + print( + r"Time integration (not fused) & %d \\" + % len(stepper.bound_op.eval_code.instructions)) + print( + r"Right-hand side (not fused) & %d \\" + % len(stepper.grudge_bound_op.eval_code.instructions)) + print( + r"Fused operator & %d \\" + % len(fused_stepper.bound_op.eval_code.instructions)) + print(r"\bottomrule") + print(r"\end{tabular}") + + +""" +def graphs(): + cl_ctx = cl.create_some_context() + queue = cl.CommandQueue(cl_ctx) + + fused_stepper = get_example_stepper(queue, use_fusion=True) + stepper = get_example_stepper(queue, use_fusion=False) + + from grudge.symbolic.compiler import dot_dataflow_graph +""" + + +def mem_ops_table(): + cl_ctx = cl.create_some_context() + queue = cl.CommandQueue(cl_ctx) + + fused_stepper = get_example_stepper( + queue, + use_fusion=True, + exec_mapper_factory=MemOpCountingExecutionMapper) + + stepper, ic = get_example_stepper( + queue, + use_fusion=False, + exec_mapper_factory=MemOpCountingExecutionMapper, + return_ic=True) + + t_start = 0 + dt = 0.02 + t_end = 0.02 + + for (_, _, profile_data) in stepper.run( + ic, t_start, dt, t_end, return_profile_data=True): + pass + + nonfused_bytes_read = profile_data["bytes_read"] + nonfused_bytes_written = profile_data["bytes_written"] + nonfused_bytes_total = nonfused_bytes_read + nonfused_bytes_written + + for (_, _, profile_data) in fused_stepper.run( + ic, t_start, dt, t_end, return_profile_data=True): + pass + + fused_bytes_read = profile_data["bytes_read"] + fused_bytes_written = profile_data["bytes_written"] + fused_bytes_total = fused_bytes_read + fused_bytes_written + + print(r"\begin{tabular}{lrrr}") + print(r"\toprule") + print(r"Operator & Bytes Read & Bytes Written & Total (\% of Baseline) \\") + print(r"\midrule") + print( + 
r"Baseline & \num{%d} & \num{%d} & \num{%d} (100) \\" + % ( + nonfused_bytes_read, + nonfused_bytes_written, + nonfused_bytes_total)) + print( + r"Fused & \num{%d} & \num{%d} & \num{%d} (%.1f) \\" + % ( + fused_bytes_read, + fused_bytes_written, + fused_bytes_total, + 100 * fused_bytes_total / nonfused_bytes_total)) + print(r"\bottomrule") + print(r"\end{tabular}") # }}} if __name__ == "__main__": #test_stepper_equivalence() - test_stepper_mem_ops() + #test_stepper_mem_ops(True) + #test_stepper_mem_ops(False) + #statement_counts_table() + #mem_ops_table() + pass