From a6164cf22f0c2ac0646e060a699fe038e1d9626a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 11 Apr 2022 19:30:20 -0500 Subject: [PATCH] Add a helper to enable use of compiled LSRK45 --- examples/wave/wave-op-mpi.py | 52 ++++++++++++++++++++---------------- grudge/shortcuts.py | 33 +++++++++++++++++++++-- test/test_grudge.py | 40 ++++++++++++++------------- 3 files changed, 81 insertions(+), 44 deletions(-) diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py index ab6c41e5..e3324044 100644 --- a/examples/wave/wave-op-mpi.py +++ b/examples/wave/wave-op-mpi.py @@ -30,7 +30,7 @@ import pyopencl as cl import pyopencl.tools as cl_tools from arraycontext import ( - thaw, freeze, + thaw, with_container_arithmetic, dataclass_array_container ) @@ -45,7 +45,7 @@ from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa from grudge.dof_desc import as_dofdesc, DOFDesc, DISCR_TAG_BASE, DISCR_TAG_QUAD from grudge.trace_pair import TracePair from grudge.discretization import DiscretizationCollection -from grudge.shortcuts import make_visualizer, rk4_step +from grudge.shortcuts import make_visualizer, compiled_lsrk45_step import grudge.op as op @@ -57,7 +57,8 @@ from mpi4py import MPI # {{{ wave equation bits -@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True, + _cls_has_array_context_attr=True) @dataclass_array_container @dataclass(frozen=True) class WaveState: @@ -251,7 +252,8 @@ def main(ctx_factory, dim=2, order=3, c = 1 # FIXME: Sketchy, empirically determined fudge factor - dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c)) + # 5/4 to account for larger LSRK45 stability region + dt = actx.to_numpy(0.45 * estimate_rk4_timestep(actx, dcoll, c)) * 5/4 vis = make_visualizer(dcoll) @@ -271,25 +273,32 @@ def main(ctx_factory, dim=2, order=3, istep = 0 while t < t_final: start = time.time() - if lazy: - fields = thaw(freeze(fields, actx), actx) - fields = rk4_step(fields, t, dt, compiled_rhs) - - l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2)) + fields = compiled_lsrk45_step(actx, fields, t, dt, compiled_rhs) if istep % 10 == 0: stop = time.time() - linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf)) - nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u)) - nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u)) - if comm.rank == 0: - logger.info(f"step: {istep} t: {t} " - f"L2: {l2norm} " - f"Linf: {linfnorm} " - f"sol max: {nodalmax} " - f"sol min: {nodalmin} " - f"wall: {stop-start} ") + if args.no_diagnostics: + if comm.rank == 0: + logger.info(f"step: {istep} t: {t} " + f"wall: {stop-start} ") + else: + l2norm = actx.to_numpy(op.norm(dcoll, fields.u, 2)) + + # NOTE: These are here to ensure the solution is bounded for the + # time interval specified + assert l2norm < 1 + + linfnorm = actx.to_numpy(op.norm(dcoll, fields.u, np.inf)) + nodalmax = actx.to_numpy(op.nodal_max(dcoll, "vol", fields.u)) + nodalmin = actx.to_numpy(op.nodal_min(dcoll, "vol", fields.u)) + if comm.rank == 0: + logger.info(f"step: {istep} t: {t} " + f"L2: {l2norm} " + f"Linf: {linfnorm} " + f"sol max: {nodalmax} " + f"sol min: {nodalmin} " + f"wall: {stop-start} ") if visualize: vis.write_parallel_vtk_file( comm, @@ -304,10 +313,6 @@ def main(ctx_factory, dim=2, order=3, t += dt istep += 1 - # NOTE: These are here to ensure the solution is bounded for the - # time interval specified - assert l2norm < 1 - if __name__ == "__main__": import argparse @@ -320,6 +325,7 @@ if __name__ == "__main__": help="switch to a lazy computation mode") parser.add_argument("--quad", action="store_true") parser.add_argument("--nonaffine", action="store_true") + parser.add_argument("--no-diagnostics", action="store_true") args = parser.parse_args() diff --git a/grudge/shortcuts.py b/grudge/shortcuts.py index 52c77097..0aca64a5 100644 --- a/grudge/shortcuts.py +++ b/grudge/shortcuts.py @@ -1,5 +1,3 @@ -"""Minimal example of a grudge driver.""" - __copyright__ = "Copyright (C) 2009 Andreas Kloeckner" __license__ = """ @@ -23,6 +21,9 @@ THE SOFTWARE. """ +from pytools import memoize_in + + def rk4_step(y, t, h, f): k1 = f(t, y) k2 = f(t+h/2, y + h/2*k1) @@ -31,6 +32,34 @@ def rk4_step(y, t, h, f): return y + h/6*(k1 + 2*k2 + 2*k3 + k4) +def _lsrk45_update(y, a, b, h, rhs_val, residual=0): + residual = a*residual + h*rhs_val + y = y + b * residual + from pytools.obj_array import make_obj_array + return make_obj_array([y, residual]) + + +def compiled_lsrk45_step(actx, y, t, h, f): + from leap.rk import LSRK4MethodBuilder + + @memoize_in(actx, (compiled_lsrk45_step, "update")) + def get_state_updater(): + return actx.compile(_lsrk45_update) + + update = get_state_updater() + + residual = None + + for a, b, c in LSRK4MethodBuilder.coeffs: # pylint: disable=not-an-iterable + rhs_val = f(t + c*h, y) + if residual is None: + y, residual = update(y, a, b, h, rhs_val) + else: + y, residual = update(y, a, b, h, rhs_val, residual) + + return y + + def set_up_rk4(field_var_name, dt, fields, rhs, t_start=0.0): from leap.rk import LSRK4MethodBuilder from dagrt.codegen import PythonCodeGenerator diff --git a/test/test_grudge.py b/test/test_grudge.py index 06f0710c..7e469d46 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -724,6 +724,8 @@ def test_convergence_advec(actx_factory, mesh_name, mesh_pars, op_type, flux_typ def rhs(t, u): return adv_operator.operator(t, u) + compiled_rhs = actx.compile(rhs) + if dim == 3: final_time = 0.1 else: @@ -734,35 +736,35 @@ def test_convergence_advec(actx_factory, mesh_name, mesh_pars, op_type, flux_typ h_max = h_max_from_volume(dcoll, dim=dcoll.ambient_dim) dt = actx.to_numpy(dt_factor * h_max/order**2) nsteps = (final_time // dt) + 1 - dt = final_time/nsteps + 1e-15 - - from grudge.shortcuts import set_up_rk4 - dt_stepper = set_up_rk4("u", dt, u, rhs) - - last_u = None + tol = 1e-14 + dt = final_time/nsteps + tol - from grudge.shortcuts import make_visualizer + from grudge.shortcuts import make_visualizer, compiled_lsrk45_step vis = make_visualizer(dcoll) step = 0 + t = 0 - for event in dt_stepper.run(t_end=final_time): - if isinstance(event, dt_stepper.StateComputed): - step += 1 - logger.debug("[%04d] t = %.5f", step, event.t) + while t < final_time - tol: + step += 1 + logger.debug("[%04d] t = %.5f", step, t) + + u = compiled_lsrk45_step(actx, u, t, dt, compiled_rhs) + + if visualize: + vis.write_vtk_file( + "fld-%s-%04d.vtu" % (mesh_par, step), + [("u", u)] + ) - last_t = event.t - last_u = event.state_component + t += dt - if visualize: - vis.write_vtk_file( - "fld-%s-%04d.vtu" % (mesh_par, step), - [("u", event.state_component)] - ) + if t + dt >= final_time - tol: + dt = final_time-t error_l2 = op.norm( dcoll, - last_u - u_analytic(nodes, t=last_t), + u - u_analytic(nodes, t=t), 2 ) logger.info("h_max %.5e error %.5e", actx.to_numpy(h_max), error_l2) -- GitLab