From fb9df62644f62e12a5b42185faeff0976a58a777 Mon Sep 17 00:00:00 2001 From: "[6~" Date: Mon, 18 Nov 2019 18:00:32 -0600 Subject: [PATCH] dagrt fusion example: build DG operator from pieces, without using library functionality --- examples/dagrt-fusion.py | 118 ++++++++++++++++++++++++++++++++++----- 1 file changed, 103 insertions(+), 15 deletions(-) diff --git a/examples/dagrt-fusion.py b/examples/dagrt-fusion.py index 66b35eb4..c12d6360 100755 --- a/examples/dagrt-fusion.py +++ b/examples/dagrt-fusion.py @@ -435,7 +435,88 @@ def get_strong_wave_op_with_discr(cl_ctx, dims=2, order=4): op.check_bc_coverage(mesh) - return (op, discr) + return (op.sym_operator(), discr) + + +def dg_flux(c, tpair): + u = tpair[0] + v = tpair[1:] + + dims = len(v) + + normal = sym.normal(tpair.dd, dims) + + flux_weak = sym.join_fields( + np.dot(v.avg, normal), + u.avg * normal) + + flux_weak -= (1 if c > 0 else -1)*sym.join_fields( + 0.5*(u.int-u.ext), + 0.5*(normal * np.dot(normal, v.int-v.ext))) + + flux_strong = sym.join_fields( + np.dot(v.int, normal), + u.int * normal) - flux_weak + + return sym.interp(tpair.dd, "all_faces")(c*flux_strong) + + +def get_strong_wave_op_with_discr_direct(cl_ctx, dims=2, order=4): + from meshmode.mesh.generation import generate_regular_rect_mesh + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dims, + b=(0.5,)*dims, + n=(16,)*dims) + + logger.debug("%d elements", mesh.nelements) + + discr = DGDiscretizationWithBoundaries(cl_ctx, mesh, order=order) + + source_center = np.array([0.1, 0.22, 0.33])[:dims] + source_width = 0.05 + source_omega = 3 + + sym_x = sym.nodes(mesh.dim) + sym_source_center_dist = sym_x - source_center + sym_t = sym.ScalarVariable("t") + + from meshmode.mesh import BTAG_ALL + + c = -0.1 + sign = -1 + + w = sym.make_sym_array("w", dims+1) + u = w[0] + v = w[1:] + + source_f = ( + sym.sin(source_omega*sym_t) + * sym.exp( + -np.dot(sym_source_center_dist, sym_source_center_dist) + / source_width**2)) + + rad_normal = sym.normal(BTAG_ALL, dims) + + rad_u = sym.cse(sym.interp("vol", BTAG_ALL)(u)) + rad_v = sym.cse(sym.interp("vol", BTAG_ALL)(v)) + + rad_bc = sym.cse(sym.join_fields( + 0.5*(rad_u - sign*np.dot(rad_normal, rad_v)), + 0.5*rad_normal*(np.dot(rad_normal, rad_v) - sign*rad_u) + ), "rad_bc") + + sym_operator = ( + - sym.join_fields( + -c*np.dot(sym.nabla(dims), v) - source_f, + -c*(sym.nabla(dims)*u) + ) + + sym.InverseMassOperator()( + sym.FaceMassOperator()( + dg_flux(c, sym.int_tpair(w)) + + dg_flux(c, sym.bv_tpair(BTAG_ALL, w, rad_bc)) + ))) + + return (sym_operator, discr) def get_strong_wave_component(state_component): @@ -452,7 +533,10 @@ def test_stepper_equivalence(ctx_factory, order=4): dims = 2 - op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=order) + sym_operator, _ = get_strong_wave_op_with_discr( + cl_ctx, dims=dims, order=order) + sym_operator_direct, discr = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=dims, order=order) if dims == 2: dt = 0.04 @@ -463,13 +547,13 @@ def test_stepper_equivalence(ctx_factory, order=4): ic = join_fields(discr.zeros(queue), [discr.zeros(queue) for i in range(discr.dim)]) - bound_op = bind(discr, op.sym_operator()) + bound_op = bind(discr, sym_operator) stepper = RK4TimeStepper( queue, discr, "w", bound_op, 1 + discr.dim, get_strong_wave_component) fused_stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator_direct, 1 + discr.dim, get_strong_wave_component) t_start = 0 @@ -1201,17 +1285,21 @@ def scalar_assignment_effect_of_fusion_mem_ops_table(): def main(): - if not SKIP_TESTS: - # Run tests. - from py.test import main - result = main([__file__]) - assert result == 0 - - # Run examples. - problem_stats() - statement_counts_table() - scalar_assignment_percent_of_total_mem_ops_table() - scalar_assignment_effect_of_fusion_mem_ops_table() + import sys + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + if not SKIP_TESTS: + # Run tests. + from py.test import main + result = main([__file__]) + assert result == 0 + + # Run examples. + problem_stats() + statement_counts_table() + scalar_assignment_percent_of_total_mem_ops_table() + scalar_assignment_effect_of_fusion_mem_ops_table() if __name__ == "__main__": -- GitLab