diff --git a/examples/dagrt-fusion.py b/examples/dagrt-fusion.py index a69f936cb3da95c0c6d251997a2c7e7d7e01aafa..e0f35b02f96303875ca40521663268ea0f491a5d 100755 --- a/examples/dagrt-fusion.py +++ b/examples/dagrt-fusion.py @@ -32,6 +32,7 @@ THE SOFTWARE. import logging import numpy as np +import os import six import sys import pyopencl as cl @@ -41,7 +42,7 @@ import pytest import dagrt.language as lang import pymbolic.primitives as p import grudge.symbolic.mappers as gmap -import grudge.symbolic.operators as op +import grudge.symbolic.operators as sym_op from grudge.execution import ExecutionMapper from grudge.function_registry import base_function_registry from pymbolic.mapper import Mapper @@ -61,7 +62,16 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -PRINT_RESULTS_TO_STDOUT = True +SKIP_TESTS = int(os.environ.get("SKIP_TESTS", 0)) +PAPER_OUTPUT = int(os.environ.get("PAPER_OUTPUT", 0)) +OUT_DIR = os.environ.get("OUT_DIR", ".") + + +def open_output_file(filename): + if not PAPER_OUTPUT: + return sys.stdout + else: + return open(os.path.join(OUT_DIR, filename), "w") # {{{ topological sort @@ -425,7 +435,88 @@ def get_strong_wave_op_with_discr(cl_ctx, dims=2, order=4): op.check_bc_coverage(mesh) - return (op, discr) + return (op.sym_operator(), discr) + + +def dg_flux(c, tpair): + u = tpair[0] + v = tpair[1:] + + dims = len(v) + + normal = sym.normal(tpair.dd, dims) + + flux_weak = sym.join_fields( + np.dot(v.avg, normal), + u.avg * normal) + + flux_weak -= (1 if c > 0 else -1)*sym.join_fields( + 0.5*(u.int-u.ext), + 0.5*(normal * np.dot(normal, v.int-v.ext))) + + flux_strong = sym.join_fields( + np.dot(v.int, normal), + u.int * normal) - flux_weak + + return sym.interp(tpair.dd, "all_faces")(c*flux_strong) + + +def get_strong_wave_op_with_discr_direct(cl_ctx, dims=2, order=4): + from meshmode.mesh.generation import generate_regular_rect_mesh + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dims, + b=(0.5,)*dims, + n=(16,)*dims) + + logger.debug("%d elements", mesh.nelements) + + discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=order) + + source_center = np.array([0.1, 0.22, 0.33])[:dims] + source_width = 0.05 + source_omega = 3 + + sym_x = sym.nodes(mesh.dim) + sym_source_center_dist = sym_x - source_center + sym_t = sym.ScalarVariable("t") + + from meshmode.mesh import BTAG_ALL + + c = -0.1 + sign = -1 + + w = sym.make_sym_array("w", dims+1) + u = w[0] + v = w[1:] + + source_f = ( + sym.sin(source_omega*sym_t) + * sym.exp( + -np.dot(sym_source_center_dist, sym_source_center_dist) + / source_width**2)) + + rad_normal = sym.normal(BTAG_ALL, dims) + + rad_u = sym.cse(sym.interp("vol", BTAG_ALL)(u)) + rad_v = sym.cse(sym.interp("vol", BTAG_ALL)(v)) + + rad_bc = sym.cse(sym.join_fields( + 0.5*(rad_u - sign*np.dot(rad_normal, rad_v)), + 0.5*rad_normal*(np.dot(rad_normal, rad_v) - sign*rad_u) + ), "rad_bc") + + sym_operator = ( + - sym.join_fields( + -c*np.dot(sym.nabla(dims), v) - source_f, + -c*(sym.nabla(dims)*u) + ) + + sym.InverseMassOperator()( + sym.FaceMassOperator()( + dg_flux(c, sym.int_tpair(w)) + + dg_flux(c, sym.bv_tpair(BTAG_ALL, w, rad_bc)) + ))) + + return (sym_operator, discr) def get_strong_wave_component(state_component): @@ -442,7 +533,10 @@ def test_stepper_equivalence(ctx_factory, order=4): dims = 2 - op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=order) + sym_operator, _ = get_strong_wave_op_with_discr( + cl_ctx, dims=dims, order=order) + sym_operator_direct, discr = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=dims, order=order) if dims == 2: dt = 0.04 @@ -453,13 +547,13 @@ def test_stepper_equivalence(ctx_factory, order=4): ic = join_fields(discr.zeros(queue), [discr.zeros(queue) for i in range(discr.dim)]) - bound_op = bind(discr, op.sym_operator()) + bound_op = bind(discr, sym_operator) stepper = RK4TimeStepper( queue, discr, "w", bound_op, 1 + discr.dim, get_strong_wave_component) fused_stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator_direct, 1 + discr.dim, get_strong_wave_component) t_start = 0 @@ -554,10 +648,10 @@ class ExecutionMapperWithMemOpCounting(ExecutionMapperWrapper): expr.op, ( # TODO: Not comprehensive. - op.InterpolationOperator, - op.RefFaceMassOperator, - op.RefInverseMassOperator, - op.OppositeInteriorFaceSwap)): + sym_op.InterpolationOperator, + sym_op.RefFaceMassOperator, + sym_op.RefInverseMassOperator, + sym_op.OppositeInteriorFaceSwap)): val = self.map_profiled_essentially_elementwise_linear( expr.op, expr.field, profile_data) @@ -667,7 +761,7 @@ def test_assignment_memory_model(ctx_factory): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) - _, discr = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=3) + _, discr = get_strong_wave_op_with_discr_direct(cl_ctx, dims=2, order=3) # Assignment instruction bound_op = bind( @@ -696,7 +790,8 @@ def test_stepper_mem_ops(ctx_factory, use_fusion): dims = 2 - op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=3) + sym_operator, discr = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=dims, order=3) t_start = 0 dt = 0.04 @@ -708,7 +803,7 @@ def test_stepper_mem_ops(ctx_factory, use_fusion): if not use_fusion: bound_op = bind( - discr, op.sym_operator(), + discr, sym_operator, exec_mapper_factory=ExecutionMapperWithMemOpCounting) stepper = RK4TimeStepper( @@ -718,7 +813,7 @@ def test_stepper_mem_ops(ctx_factory, use_fusion): else: stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator, 1 + discr.dim, get_strong_wave_component, exec_mapper_factory=ExecutionMapperWithMemOpCounting) @@ -866,7 +961,8 @@ def test_stepper_timing(ctx_factory, use_fusion): dims = 3 - op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=3) + sym_operator, discr = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=dims, order=3) t_start = 0 dt = 0.04 @@ -878,7 +974,7 @@ def test_stepper_timing(ctx_factory, use_fusion): if not use_fusion: bound_op = bind( - discr, op.sym_operator(), + discr, sym_operator, exec_mapper_factory=ExecutionMapperWithTiming) stepper = RK4TimeStepper( @@ -888,7 +984,7 @@ def test_stepper_timing(ctx_factory, use_fusion): else: stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator, 1 + discr.dim, get_strong_wave_component, exec_mapper_factory=ExecutionMapperWithTiming) @@ -917,11 +1013,12 @@ def test_stepper_timing(ctx_factory, use_fusion): def get_example_stepper(queue, dims=2, order=3, use_fusion=True, exec_mapper_factory=ExecutionMapper, return_ic=False): - op, discr = get_strong_wave_op_with_discr(queue.context, dims=dims, order=3) + sym_operator, discr = get_strong_wave_op_with_discr_direct( + queue.context, dims=dims, order=3) if not use_fusion: bound_op = bind( - discr, op.sym_operator(), + discr, sym_operator, exec_mapper_factory=exec_mapper_factory) stepper = RK4TimeStepper( @@ -931,7 +1028,7 @@ def get_example_stepper(queue, dims=2, order=3, use_fusion=True, else: stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator, 1 + discr.dim, get_strong_wave_component, exec_mapper_factory=exec_mapper_factory) @@ -976,7 +1073,7 @@ def ascii_table(table_format, header, rows): return str(table) -if PRINT_RESULTS_TO_STDOUT: +if not PAPER_OUTPUT: table = ascii_table else: table = latex_table @@ -984,20 +1081,24 @@ else: def problem_stats(order=3): cl_ctx = cl.create_some_context() + outf = open_output_file("grudge-problem-stats.txt") - _, dg_discr_2d = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=order) - print("Number of 2D elements:", dg_discr_2d.mesh.nelements) + _, dg_discr_2d = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=2, order=order) + print("Number of 2D elements:", dg_discr_2d.mesh.nelements, file=outf) vol_discr_2d = dg_discr_2d.discr_from_dd("vol") dofs_2d = {group.nunit_nodes for group in vol_discr_2d.groups} from pytools import one - print("Number of DOFs per 2D element:", one(dofs_2d)) + print("Number of DOFs per 2D element:", one(dofs_2d), file=outf) - _, dg_discr_3d = get_strong_wave_op_with_discr(cl_ctx, dims=3, order=order) - print("Number of 3D elements:", dg_discr_3d.mesh.nelements) + _, dg_discr_3d = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=3, order=order) + print("Number of 3D elements:", dg_discr_3d.mesh.nelements, file=outf) vol_discr_3d = dg_discr_3d.discr_from_dd("vol") dofs_3d = {group.nunit_nodes for group in vol_discr_3d.groups} from pytools import one - print("Number of DOFs per 3D element:", one(dofs_3d)) + print("Number of DOFs per 3D element:", one(dofs_3d), file=outf) + logger.info("Wrote '%s'", outf.name) def statement_counts_table(): @@ -1007,12 +1108,9 @@ def statement_counts_table(): fused_stepper = get_example_stepper(queue, use_fusion=True) stepper = get_example_stepper(queue, use_fusion=False) - if PRINT_RESULTS_TO_STDOUT: - print("==== Statement Counts ====") - outf = sys.stdout - else: - out_path = "statement-counts.tex" - outf = open(out_path, "w") + outf = open_output_file("statement-counts.tex") + if not PAPER_OUTPUT: + print("==== Statement Counts ====", file=outf) print( table( @@ -1027,6 +1125,7 @@ def statement_counts_table(): r"\num{%d}" % len(fused_stepper.bound_op.eval_code.instructions)) )), file=outf) + logger.info("Wrote '%s'", outf.name) @memoize(key=lambda queue, dims: dims) @@ -1096,12 +1195,9 @@ def scalar_assignment_percent_of_total_mem_ops_table(): result2d = mem_ops_results(queue, 2) result3d = mem_ops_results(queue, 3) - if PRINT_RESULTS_TO_STDOUT: - print("==== Scalar Assigment % of Total Mem Ops ====") - outf = sys.stdout - else: - out_path = "scalar-assignments-mem-op-percentage.tex" - outf = open(out_path, "w") + outf = open_output_file("scalar-assignments-mem-op-percentage.tex") + if not PAPER_OUTPUT: + print("==== Scalar Assigment % of Total Mem Ops ====", file=outf) print( table( @@ -1127,6 +1223,7 @@ def scalar_assignment_percent_of_total_mem_ops_table(): / result3d["fused_bytes_total"])), )), file=outf) + logger.info("Wrote '%s'", outf.name) def scalar_assignment_effect_of_fusion_mem_ops_table(): @@ -1136,12 +1233,9 @@ def scalar_assignment_effect_of_fusion_mem_ops_table(): result2d = mem_ops_results(queue, 2) result3d = mem_ops_results(queue, 3) - if PRINT_RESULTS_TO_STDOUT: - print("==== Scalar Assigment Inlining Impact ====") - outf = sys.stdout - else: - out_path = "scalar-assignments-mem-op-percentage.tex" - outf = open(out_path, "w") + outf = open_output_file("scalar-assignments-fusion-impact.tex") + if not PAPER_OUTPUT: + print("==== Scalar Assigment Inlining Impact ====", file=outf) print( table( @@ -1190,23 +1284,30 @@ def scalar_assignment_effect_of_fusion_mem_ops_table(): / result3d["nonfused_bytes_total_by_scalar_assignments"])), )), file=outf) + logger.info("Wrote '%s'", outf.name) # }}} def main(): - if 1: - # Run tests. - from py.test import main - result = main([__file__]) - assert result == 0 + import sys + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + if not SKIP_TESTS: + # Run tests. + from py.test import main + result = main([__file__]) + assert result == 0 - # Run examples. - problem_stats() - statement_counts_table() - scalar_assignment_percent_of_total_mem_ops_table() - scalar_assignment_effect_of_fusion_mem_ops_table() + # Run examples. + problem_stats() + statement_counts_table() + scalar_assignment_percent_of_total_mem_ops_table() + scalar_assignment_effect_of_fusion_mem_ops_table() if __name__ == "__main__": main() + +# vim: foldmethod=marker