diff --git a/examples/dagrt-fusion.py b/examples/dagrt-fusion.py index a69f936cb3da95c0c6d251997a2c7e7d7e01aafa..66b35eb420b5e7d69d5095a9870d666ec3a99f56 100755 --- a/examples/dagrt-fusion.py +++ b/examples/dagrt-fusion.py @@ -32,6 +32,7 @@ THE SOFTWARE. import logging import numpy as np +import os import six import sys import pyopencl as cl @@ -61,7 +62,16 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -PRINT_RESULTS_TO_STDOUT = True +SKIP_TESTS = int(os.environ.get("SKIP_TESTS", 0)) +PAPER_OUTPUT = int(os.environ.get("PAPER_OUTPUT", 0)) +OUT_DIR = os.environ.get("OUT_DIR", ".") + + +def open_output_file(filename): + if not PAPER_OUTPUT: + return sys.stdout + else: + return open(os.path.join(OUT_DIR, filename), "w") # {{{ topological sort @@ -976,7 +986,7 @@ def ascii_table(table_format, header, rows): return str(table) -if PRINT_RESULTS_TO_STDOUT: +if not PAPER_OUTPUT: table = ascii_table else: table = latex_table @@ -984,20 +994,22 @@ else: def problem_stats(order=3): cl_ctx = cl.create_some_context() + outf = open_output_file("grudge-problem-stats.txt") _, dg_discr_2d = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=order) - print("Number of 2D elements:", dg_discr_2d.mesh.nelements) + print("Number of 2D elements:", dg_discr_2d.mesh.nelements, file=outf) vol_discr_2d = dg_discr_2d.discr_from_dd("vol") dofs_2d = {group.nunit_nodes for group in vol_discr_2d.groups} from pytools import one - print("Number of DOFs per 2D element:", one(dofs_2d)) + print("Number of DOFs per 2D element:", one(dofs_2d), file=outf) _, dg_discr_3d = get_strong_wave_op_with_discr(cl_ctx, dims=3, order=order) - print("Number of 3D elements:", dg_discr_3d.mesh.nelements) + print("Number of 3D elements:", dg_discr_3d.mesh.nelements, file=outf) vol_discr_3d = dg_discr_3d.discr_from_dd("vol") dofs_3d = {group.nunit_nodes for group in vol_discr_3d.groups} from pytools import one - print("Number of DOFs per 3D element:", one(dofs_3d)) + print("Number of DOFs per 3D element:", one(dofs_3d), file=outf) + logger.info("Wrote '%s'", outf.name) def statement_counts_table(): @@ -1007,12 +1019,9 @@ def statement_counts_table(): fused_stepper = get_example_stepper(queue, use_fusion=True) stepper = get_example_stepper(queue, use_fusion=False) - if PRINT_RESULTS_TO_STDOUT: - print("==== Statement Counts ====") - outf = sys.stdout - else: - out_path = "statement-counts.tex" - outf = open(out_path, "w") + outf = open_output_file("statement-counts.tex") + if not PAPER_OUTPUT: + print("==== Statement Counts ====", file=outf) print( table( @@ -1027,6 +1036,7 @@ def statement_counts_table(): r"\num{%d}" % len(fused_stepper.bound_op.eval_code.instructions)) )), file=outf) + logger.info("Wrote '%s'", outf.name) @memoize(key=lambda queue, dims: dims) @@ -1096,12 +1106,9 @@ def scalar_assignment_percent_of_total_mem_ops_table(): result2d = mem_ops_results(queue, 2) result3d = mem_ops_results(queue, 3) - if PRINT_RESULTS_TO_STDOUT: - print("==== Scalar Assigment % of Total Mem Ops ====") - outf = sys.stdout - else: - out_path = "scalar-assignments-mem-op-percentage.tex" - outf = open(out_path, "w") + outf = open_output_file("scalar-assignments-mem-op-percentage.tex") + if not PAPER_OUTPUT: + print("==== Scalar Assigment % of Total Mem Ops ====", file=outf) print( table( @@ -1127,6 +1134,7 @@ def scalar_assignment_percent_of_total_mem_ops_table(): / result3d["fused_bytes_total"])), )), file=outf) + logger.info("Wrote '%s'", outf.name) def scalar_assignment_effect_of_fusion_mem_ops_table(): @@ -1136,12 +1144,9 @@ def scalar_assignment_effect_of_fusion_mem_ops_table(): result2d = mem_ops_results(queue, 2) result3d = mem_ops_results(queue, 3) - if PRINT_RESULTS_TO_STDOUT: - print("==== Scalar Assigment Inlining Impact ====") - outf = sys.stdout - else: - out_path = "scalar-assignments-mem-op-percentage.tex" - outf = open(out_path, "w") + outf = open_output_file("scalar-assignments-fusion-impact.tex") + if not PAPER_OUTPUT: + print("==== Scalar Assigment Inlining Impact ====", file=outf) print( table( @@ -1190,12 +1195,13 @@ def scalar_assignment_effect_of_fusion_mem_ops_table(): / result3d["nonfused_bytes_total_by_scalar_assignments"])), )), file=outf) + logger.info("Wrote '%s'", outf.name) # }}} def main(): - if 1: + if not SKIP_TESTS: # Run tests. from py.test import main result = main([__file__])