Skip to content
Snippets Groups Projects
dagrt-fusion.py 42.9 KiB
Newer Older
  • Learn to ignore specific revisions
  • Matt Wala's avatar
    Matt Wala committed
        nsteps = int(np.ceil((t_end + 1e-9) / dt))
        for (_, _, profile_data) in stepper.run(
                ic, t_start, dt, t_end, return_profile_data=True):
            step += 1
            tn = time.time()
            logger.info("step %d/%d: %f", step, nsteps, tn - t)
            t = tn
    
        logger.info("fusion? %s", use_fusion)
    
    Matt Wala's avatar
    Matt Wala committed
        for key, value in profile_data.items():
            if isinstance(value, TimingFutureList):
                print(key, value.elapsed())
    
    Matt Wala's avatar
    Matt Wala committed
    # {{{ paper outputs
    
    def get_example_stepper(queue, dims=2, order=3, use_fusion=True,
                            exec_mapper_factory=ExecutionMapper,
                            return_ic=False):
        """Build an RK4 time stepper for the strong-form wave operator.

        :arg queue: a :mod:`pyopencl` command queue; its context is used to
            build the operator and discretization.
        :arg dims: spatial dimension passed to the operator builder.
        :arg order: discretization order passed to the operator builder.
        :arg use_fusion: if *True*, return a :class:`FusedRK4TimeStepper`
            built directly from the symbolic operator; otherwise bind the
            operator with :func:`bind` and wrap it in an
            :class:`RK4TimeStepper`.
        :arg exec_mapper_factory: execution mapper factory forwarded to the
            stepper (and, in the non-fused case, to :func:`bind`).
        :arg return_ic: if *True*, also return an all-zero initial condition
            with ``1 + discr.dim`` components.

        :returns: the stepper, or ``(stepper, ic)`` when *return_ic* is set.
        """
        # Forward *order*: this call used to hard-code ``order=3`` and
        # silently ignored the argument.
        sym_operator, discr = get_strong_wave_op_with_discr_direct(
                queue.context, dims=dims, order=order)

        if not use_fusion:
            bound_op = bind(
                    discr, sym_operator,
                    exec_mapper_factory=exec_mapper_factory)

            stepper = RK4TimeStepper(
                    queue, discr, "w", bound_op, 1 + discr.dim,
                    get_strong_wave_component,
                    exec_mapper_factory=exec_mapper_factory)

        else:
            stepper = FusedRK4TimeStepper(
                    queue, discr, "w", sym_operator, 1 + discr.dim,
                    get_strong_wave_component,
                    exec_mapper_factory=exec_mapper_factory)

        if return_ic:
            from pytools.obj_array import join_fields
            # One leading field followed by ``discr.dim`` further fields,
            # all zero.
            ic = join_fields(discr.zeros(queue),
                    [discr.zeros(queue) for _ in range(discr.dim)])
            return stepper, ic

        return stepper
    
    
    
    def latex_table(table_format, header, rows):
        r"""Render a booktabs-style LaTeX ``tabular`` environment.

        :arg table_format: the column specification, e.g. ``"lr"``.
        :arg header: header cells; each is wrapped in
            ``\multicolumn{1}{c}{...}`` so that it is centered.
        :arg rows: iterable of rows; each row is an iterable of
            already-formatted string cells.
        :returns: the table source as one newline-separated string.
        """
        header_line = " & ".join(
            rf"\multicolumn{{1}}{{c}}{{{cell}}}" for cell in header)
        body_lines = [" & ".join(cells) + r" \\" for cells in rows]

        return "\n".join([
            rf"\begin{{tabular}}{{{table_format}}}",
            r"\toprule",
            header_line + r" \\",
            r"\midrule",
            *body_lines,
            r"\bottomrule",
            r"\end{tabular}",
        ])
    
    
    
    def ascii_table(table_format, header, rows):
        r"""Render a header and rows as a plain-text :class:`pytools.Table`.

        *table_format* is accepted only so the signature matches
        :func:`latex_table`; it is not used. Any cell that starts with
        ``\num{`` has that prefix and its final character stripped, so
        numbers print without the LaTeX wrapper.

        :returns: the rendered table as a string.
        """
        from pytools import Table

        def unwrap(cell):
            # Drop the 5-character ``\num{`` prefix and the last character,
            # presumed to be the matching closing brace.
            if cell.startswith(r"\num{"):
                return cell[5:-1]
            return cell

        ascii_tbl = Table()
        ascii_tbl.add_row(header)
        for src_row in rows:
            ascii_tbl.add_row([unwrap(cell) for cell in src_row])

        return str(ascii_tbl)
    
    
    
        table = ascii_table
    else:
        table = latex_table
    
    
    
    def problem_stats(order=3):
        """Write element and DOF counts of the 2D and 3D problems to a file.

        For ``dims`` in 2 and 3, builds the discretization with
        :func:`get_strong_wave_op_with_discr_direct` at the given *order* and
        writes the number of mesh elements and the number of DOFs per element
        to ``grudge-problem-stats.txt``.
        """
        from pytools import one

        cl_ctx = cl.create_some_context()

        with open_output_file("grudge-problem-stats.txt") as outf:
            for dims in (2, 3):
                _, dg_discr = get_strong_wave_op_with_discr_direct(
                        cl_ctx, dims=dims, order=order)

                print(f"Number of {dims}D elements:",
                      dg_discr.mesh.nelements, file=outf)

                vol_discr = dg_discr.discr_from_dd("vol")
                # Every volume group is expected to have the same node count;
                # ``one`` rejects a set with more than one distinct entry.
                dofs = {group.nunit_nodes for group in vol_discr.groups}
                print(f"Number of DOFs per {dims}D element:", one(dofs),
                      file=outf)

        logger.info("Wrote '%s'", outf.name)
    
    Matt Wala's avatar
    Matt Wala committed
    def statement_counts_table():
        """Write a table of instruction counts for the compiled operators.

        Reports ``len(bound_op.eval_code.instructions)`` for the baseline
        time integrator, the baseline right-hand side, and the fused
        ("inlined") time integrator, writing the table to
        ``statement-counts.tex``.
        """
        cl_ctx = cl.create_some_context()
        queue = cl.CommandQueue(cl_ctx)

        fused_stepper = get_example_stepper(queue, use_fusion=True)
        stepper = get_example_stepper(queue, use_fusion=False)

        def instruction_count(bound_op):
            # Number of instructions in the compiled code, as a LaTeX \num{}.
            return r"\num{%d}" % len(bound_op.eval_code.instructions)

        with open_output_file("statement-counts.tex") as outf:
            if not PAPER_OUTPUT:
                print("==== Statement Counts ====", file=outf)

            print(
                table(
                    "lr",
                    ("Operator", "Grudge Node Count"),
                    (
                        ("Time integration: baseline",
                         instruction_count(stepper.bound_op)),
                        ("Right-hand side: baseline",
                         instruction_count(stepper.grudge_bound_op)),
                        ("Inlined operator",
                         instruction_count(fused_stepper.bound_op)),
                    )),
                file=outf)

        logger.info("Wrote '%s'", outf.name)
    
    
    
    @memoize(key=lambda queue, dims: dims)
    def mem_ops_results(queue, dims):
        """Collect memory-traffic counters for the baseline and fused steppers.

        Each stepper runs from ``t_start = 0`` to ``t_end = 0.02`` with
        ``dt = 0.02`` under :class:`ExecutionMapperWithMemOpCounting`. Using
        the profile data from the last step yielded, the returned dict has,
        for each *prefix* in ``"nonfused"`` and ``"fused"``, the keys
        ``{prefix}_bytes_read``, ``{prefix}_bytes_written`` and
        ``{prefix}_bytes_total`` (read + written), plus the same three keys
        with the suffix ``_by_scalar_assignments``.

        The memoization key is *dims* alone; *queue* does not take part in it.
        """
        fused_stepper = get_example_stepper(
                queue,
                dims=dims,
                use_fusion=True,
                exec_mapper_factory=ExecutionMapperWithMemOpCounting)

        stepper, ic = get_example_stepper(
                queue,
                dims=dims,
                use_fusion=False,
                exec_mapper_factory=ExecutionMapperWithMemOpCounting,
                return_ic=True)

        t_start = 0
        dt = 0.02
        t_end = 0.02

        result = {}

        # Baseline first, then fused.
        for prefix, run_stepper in (("nonfused", stepper),
                                    ("fused", fused_stepper)):
            # Drain the run; only the final step's profile data is kept.
            for (_, _, profile_data) in run_stepper.run(
                    ic, t_start, dt, t_end, return_profile_data=True):
                pass

            for suffix in ("", "_by_scalar_assignments"):
                n_read = profile_data["bytes_read" + suffix]
                n_written = profile_data["bytes_written" + suffix]
                result[f"{prefix}_bytes_read{suffix}"] = n_read
                result[f"{prefix}_bytes_written{suffix}"] = n_written
                result[f"{prefix}_bytes_total{suffix}"] = n_read + n_written

        return result
    
    
    def scalar_assignment_percent_of_total_mem_ops_table():
        """Write the share of memory traffic caused by scalar assignments.

        For the 2D and 3D problems, and for both the baseline and the inlined
        (fused) stepper, reports ``100 * bytes_total_by_scalar_assignments /
        bytes_total`` with one decimal place. The table is written to
        ``scalar-assignments-mem-op-percentage.tex``.
        """
        cl_ctx = cl.create_some_context()
        queue = cl.CommandQueue(cl_ctx)

        result2d = mem_ops_results(queue, 2)
        result3d = mem_ops_results(queue, 3)

        def percent(result, prefix):
            return "%.1f" % (
                100 * result[f"{prefix}_bytes_total_by_scalar_assignments"]
                / result[f"{prefix}_bytes_total"])

        # Computed before the output file is opened so that a failure here
        # does not leave a partially written table behind.
        rows = tuple(
            (f"{label}: {variant}", percent(result, prefix))
            for label, result in (("2D", result2d), ("3D", result3d))
            for variant, prefix in (("Baseline", "nonfused"),
                                    ("Inlined", "fused")))

        with open_output_file("scalar-assignments-mem-op-percentage.tex") as outf:
            if not PAPER_OUTPUT:
                print("==== Scalar Assignment % of Total Mem Ops ====", file=outf)

            print(
                table(
                    "lr",
                    ("Operator",
                     r"\parbox{1in}{\centering \% Memory Ops. "
                     r"Due to Scalar Assignments}"),
                    rows),
                file=outf)

        logger.info("Wrote '%s'", outf.name)
    
    
    
    def scalar_assignment_effect_of_fusion_mem_ops_table():
        """Write scalar-assignment memory traffic with and without fusion.

        For the 2D and 3D problems, lists the bytes read, bytes written and
        their total attributable to scalar assignments for the baseline and
        inlined (fused) steppers. The last column gives the inlined total as
        a percentage of the baseline total. The table is written to
        ``scalar-assignments-fusion-impact.tex``.
        """
        cl_ctx = cl.create_some_context()
        queue = cl.CommandQueue(cl_ctx)

        result2d = mem_ops_results(queue, 2)
        result3d = mem_ops_results(queue, 3)

        rows = []
        for label, result in (("2D", result2d), ("3D", result3d)):
            for variant, prefix in (("Baseline", "nonfused"),
                                    ("Inlined", "fused")):
                counts = [
                    r"\num{%d}" % result[
                        f"{prefix}_bytes_{kind}_by_scalar_assignments"]
                    for kind in ("read", "written", "total")]

                if prefix == "nonfused":
                    # The baseline is the reference, i.e. 100% of itself.
                    pct = "100"
                else:
                    pct = "%.1f" % (
                        100 * result["fused_bytes_total_by_scalar_assignments"]
                        / result["nonfused_bytes_total_by_scalar_assignments"])

                rows.append((f"{label}: {variant}", *counts, pct))

        with open_output_file("scalar-assignments-fusion-impact.tex") as outf:
            if not PAPER_OUTPUT:
                print("==== Scalar Assignment Inlining Impact ====", file=outf)

            print(
                table(
                    "lrrrr",
                    ("Operator",
                     r"Bytes Read",
                     r"Bytes Written",
                     r"Total",
                     r"\% of Baseline"),
                    rows),
                file=outf)

        logger.info("Wrote '%s'", outf.name)
    
        import sys
        if len(sys.argv) > 1:
            exec(sys.argv[1])
        else:
            if not SKIP_TESTS:
                # Run tests.
                from py.test import main
                result = main([__file__])
                assert result == 0
    
            # Run examples.
            problem_stats()
            statement_counts_table()
            scalar_assignment_percent_of_total_mem_ops_table()
            scalar_assignment_effect_of_fusion_mem_ops_table()
    
    
    
    if __name__ == "__main__":
        # Run the entry point only when executed as a script, not on import.
        main()