diff --git a/test/test_distributed.py b/test/test_distributed.py
index 81468f21ffa8df35f7568c9fe306b614e80bb3be..15452294d416209ca5e640928693572543821d25 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -26,6 +26,7 @@ from pytools.graph import CycleError
 from pyopencl.tools import (  # noqa
         pytest_generate_tests_for_pyopencl as pytest_generate_tests)
 import pyopencl as cl
+import pyopencl.array as cla
 import numpy as np
 import pytato as pt
 import sys
@@ -338,6 +339,59 @@ def test_dag_with_recv_as_output():
 # }}}
 
 
+# {{{ test DAG with multiple send nodes per sent array
+
+def _test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory):
+    from numpy.random import default_rng
+    from mpi4py import MPI  # pylint: disable=import-error
+    comm = MPI.COMM_WORLD
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    # {{{ construct the DAG
+
+    if comm.rank == 0:
+        rng = default_rng()
+        x_np = rng.random((10, 4))
+        x = pt.make_data_wrapper(cla.to_device(queue, x_np))
+        y = 2 * x
+        send1 = pt.staple_distributed_send(
+            y, dest_rank=1, comm_tag=42,
+            stapled_to=pt.ones(10))
+        send2 = pt.staple_distributed_send(
+            y, dest_rank=2, comm_tag=42,
+            stapled_to=pt.ones(10))
+        z = 4 * y
+        dag = pt.make_dict_of_named_arrays({"z": z, "send1": send1, "send2": send2})
+    else:
+        y = pt.make_distributed_recv(
+            src_rank=0, comm_tag=42,
+            shape=(10, 4), dtype=np.float64)
+        z = 4 * y
+        dag = pt.make_dict_of_named_arrays({"z": z})
+
+    # }}}
+
+    parts = pt.find_distributed_partition(comm, dag)
+    pt.verify_distributed_partition(comm, parts)
+    prg_per_partition = pt.generate_code_for_partition(parts)
+    out_dict = pt.execute_distributed_partition(
+            parts, prg_per_partition, queue, comm)
+
+    if comm.rank == 0:
+        comm.bcast(x_np)
+    else:
+        x_np = comm.bcast(None)
+
+    np.testing.assert_allclose(out_dict["z"].get(), 8 * x_np)
+
+
+def test_dag_with_multiple_send_nodes_per_sent_array():
+    run_test_with_mpi(3, _test_dag_with_multiple_send_nodes_per_sent_array_inner)
+
+# }}}
+
+
 # {{{ test deterministic partitioning
 
 def _gather_random_dist_partitions(ctx_factory):