diff --git a/examples/euler/acoustic_pulse.py b/examples/euler/acoustic_pulse.py index c8fbf9303d50bb99eeee7bbb2aecac44d1a28acb..a525cf41ecd3decec035ef42abf470be4eac0119 100644 --- a/examples/euler/acoustic_pulse.py +++ b/examples/euler/acoustic_pulse.py @@ -36,7 +36,7 @@ from pytools.obj_array import make_obj_array import grudge.op as op from grudge.array_context import PyOpenCLArrayContext, PytatoPyOpenCLArrayContext from grudge.models.euler import ConservedEulerField, EulerOperator, InviscidWallBC -from grudge.shortcuts import rk4_step +from grudge.shortcuts import compiled_lsrk45_step logger = logging.getLogger(__name__) @@ -200,7 +200,7 @@ def run_acoustic_pulse(actx, assert norm_q < 5 fields = actx.thaw(actx.freeze(fields)) - fields = rk4_step(fields, t, dt, compiled_rhs) + fields = compiled_lsrk45_step(actx, fields, t, dt, compiled_rhs) t += dt step += 1 diff --git a/examples/euler/vortex.py b/examples/euler/vortex.py index 2c5f40c6f4234ce8ad5a52e67514487cedb33beb..ca6608d953ec480dfefe4349b319272697b9bab2 100644 --- a/examples/euler/vortex.py +++ b/examples/euler/vortex.py @@ -31,7 +31,7 @@ import pyopencl.tools as cl_tools import grudge.op as op from grudge.array_context import PyOpenCLArrayContext, PytatoPyOpenCLArrayContext from grudge.models.euler import EulerOperator, vortex_initial_condition -from grudge.shortcuts import rk4_step +from grudge.shortcuts import compiled_lsrk45_step logger = logging.getLogger(__name__) @@ -126,6 +126,8 @@ def run_vortex(actx, order=3, resolution=8, final_time=5, vis = make_visualizer(dcoll) + fields = actx.freeze_thaw(fields) + # {{{ time stepping step = 0 @@ -146,8 +148,7 @@ def run_vortex(actx, order=3, resolution=8, final_time=5, ) assert norm_q < 200 - fields = actx.thaw(actx.freeze(fields)) - fields = rk4_step(fields, t, dt, compiled_rhs) + fields = compiled_lsrk45_step(actx, fields, t, dt, compiled_rhs) t += dt step += 1 diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py index 61f7559ff48672a8440b84a8c0e29dd38355345e..4d8382080697e82a331263b792ca4dcb9267eb48 100644 --- a/examples/wave/wave-op-mpi.py +++ b/examples/wave/wave-op-mpi.py @@ -277,6 +277,8 @@ def main(ctx_factory, dim=2, order=3, import time start = time.time() + fields = actx.freeze_thaw(fields) + t = 0 t_final = 3 istep = 0 diff --git a/grudge/shortcuts.py b/grudge/shortcuts.py index 0c9f638485aa870d6369d274d8155e91c074603a..2367e8886dc87b81c30650b264140d039a37a2d3 100644 --- a/grudge/shortcuts.py +++ b/grudge/shortcuts.py @@ -20,6 +20,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from functools import partial + +from arraycontext import BcastUntilActxArray +from arraycontext.context import ArrayContext from pytools import memoize_in from grudge.dof_desc import DD_VOLUME_ALL @@ -33,19 +37,24 @@ def rk4_step(y, t, h, f): return y + h/6*(k1 + 2*k2 + 2*k3 + k4) -def _lsrk45_update(y, a, b, h, rhs_val, residual=0): - residual = a*residual + h*rhs_val - y = y + b * residual +def _lsrk45_update(actx: ArrayContext, y, a, b, h, rhs_val, residual=None): + bcast = partial(BcastUntilActxArray, actx) + if residual is None: + residual = bcast(h) * rhs_val + else: + residual = bcast(a) * residual + bcast(h) * rhs_val + + y = y + bcast(b) * residual from pytools.obj_array import make_obj_array return make_obj_array([y, residual]) -def compiled_lsrk45_step(actx, y, t, h, f): +def compiled_lsrk45_step(actx: ArrayContext, y, t, h, f): from leap.rk import LSRK4MethodBuilder @memoize_in(actx, (compiled_lsrk45_step, "update")) def get_state_updater(): - return actx.compile(_lsrk45_update) + return actx.compile(partial(_lsrk45_update, actx)) update = get_state_updater() diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 69b4d5c9b92d59b5c9790deb2731aaba92b20f7d..5b66be068d1c27c86d979e26225d42c19be91530 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -37,7 +37,7 @@ from pytools.obj_array import flat_obj_array from grudge import dof_desc, op from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext from grudge.discretization import make_discretization_collection -from grudge.shortcuts import rk4_step +from grudge.shortcuts import compiled_lsrk45_step logger = logging.getLogger(__name__) @@ -246,8 +246,8 @@ def _test_mpi_wave_op_entrypoint(actx, visualize=False): [dcoll.zeros(actx) for i in range(dcoll.dim)] ) - dt = actx.to_numpy( - wave_op.estimate_rk4_timestep(actx, dcoll, fields=fields)) + dt = float(actx.to_numpy( + wave_op.estimate_rk4_timestep(actx, dcoll, fields=fields))) wave_op.check_bc_coverage(local_mesh) @@ -277,10 +277,12 @@ def _test_mpi_wave_op_entrypoint(actx, visualize=False): from grudge.shortcuts import make_visualizer vis = make_visualizer(dcoll) + fields = actx.freeze_thaw(fields) + logmgr.tick_before() for step in range(nsteps): t = step*dt - fields = rk4_step(fields, t=t, h=dt, f=compiled_rhs) + fields = compiled_lsrk45_step(actx, fields, t=t, h=dt, f=compiled_rhs) fields = actx.thaw(actx.freeze(fields)) norm = actx.to_numpy(op.norm(dcoll, fields, 2))