diff --git a/test/test_distributed.py b/test/test_distributed.py
index 013f3659e4ff593e641d26ef6c46df3aad87f8fd..8d7cd50dd5d5e5f6e1c13b2f01f7a3d659335990 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -25,6 +25,7 @@ import pytest
 from pytools.graph import CycleError
 from pyopencl.tools import (  # noqa
         pytest_generate_tests_for_pyopencl as pytest_generate_tests)
+import pytest  # noqa
 import pyopencl as cl
 import pyopencl.array as cla
 import numpy as np
@@ -36,7 +37,6 @@ import os
 # {{{ mpi test infrastructure
 
 def run_test_with_mpi(num_ranks, f, *args, extra_env_vars=None):
-    import pytest
     pytest.importorskip("mpi4py")
 
     if extra_env_vars is None:
@@ -460,6 +460,67 @@ def test_dag_with_multiple_send_nodes_per_sent_array():
 # }}}
 
 
+# {{{ test DAG with periodic communication
+
+def _test_dag_with_periodic_communication_inner(ctx_factory):
+    """Compare a two-stage ring exchange DAG against an mpi4py reference."""
+    from numpy.random import default_rng
+    from mpi4py import MPI  # pylint: disable=import-error
+    comm = MPI.COMM_WORLD
+    queue = cl.CommandQueue(ctx_factory())
+
+    # {{{ construct the DAG
+
+    rank, size = comm.Get_rank(), comm.Get_size()
+    # Ring neighbors: each rank sends to rank-1 and receives from rank+1.
+    send_to, recv_from = (rank-1) % size, (rank+1) % size
+    x_np = default_rng().random((10, 4))
+    x = pt.make_data_wrapper(cla.to_device(queue, x_np))
+
+    x_plus = pt.staple_distributed_send(x, dest_rank=send_to, comm_tag=42,
+            stapled_to=pt.make_distributed_recv(
+                src_rank=recv_from, comm_tag=42, shape=(10, 4),
+                dtype=np.float64))
+
+    y = x + x_plus
+
+    y_plus = pt.staple_distributed_send(y, dest_rank=send_to, comm_tag=43,
+            stapled_to=pt.make_distributed_recv(
+                src_rank=recv_from, comm_tag=43, shape=(10, 4),
+                dtype=np.float64))
+
+    z = y + y_plus
+
+    dag = pt.make_dict_of_named_arrays({"z": z})
+
+    # }}}
+
+    parts = pt.find_distributed_partition(comm, dag)
+    pt.verify_distributed_partition(comm, parts)
+    prg_per_partition = pt.generate_code_for_partition(parts)
+    out_dict = pt.execute_distributed_partition(
+            parts, prg_per_partition, queue, comm)
+
+    # Reference via plain mpi4py; each isend request is waited on to complete it.
+    req = comm.isend(x_np, dest=send_to, tag=44)
+    ref_x_plus_np = comm.recv(source=recv_from, tag=44)
+    req.wait()
+    ref_y_np = x_np + ref_x_plus_np
+    req = comm.isend(ref_y_np, dest=send_to, tag=45)
+    ref_y_plus_np = comm.recv(source=recv_from, tag=45)
+    req.wait()
+    ref_res = ref_y_np + ref_y_plus_np
+    np.testing.assert_allclose(out_dict["z"].get(), ref_res)
+
+
+@pytest.mark.parametrize("num_ranks", [2, 3])
+def test_dag_with_periodic_communication(num_ranks):
+    """Run the periodic-communication check via run_test_with_mpi."""
+    run_test_with_mpi(num_ranks, _test_dag_with_periodic_communication_inner)
+
+# }}}
+
+
 # {{{ test deterministic partitioning
 
 def _gather_random_dist_partitions(ctx_factory):
@@ -573,7 +634,6 @@ def test_verify_distributed_partition():
 def _do_verify_distributed_partition(ctx_factory):
     from mpi4py import MPI  # pylint: disable=import-error
     comm = MPI.COMM_WORLD
-    import pytest
     from pytato.distributed.verify import (DuplicateSendError,
                 DuplicateRecvError, MissingSendError, MissingRecvError)