diff --git a/test/test_distributed.py b/test/test_distributed.py
index 97a7be549a3598fff89374a484d15e7d341ad69f..802b7e74afeb0c253a0a83cd2e563e54006276d2 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -303,6 +303,62 @@ def test_deterministic_partitioning(seed):
 # }}}
 
 
+# {{{ test Kaushik's MWE
+
+def test_kaushik_mwe():
+    run_test_with_mpi(2, _do_test_kaushik_mwe)
+
+
+def _do_test_kaushik_mwe(ctx_factory):
+    # from https://github.com/inducer/pytato/pull/393#issuecomment-1324642248
+    from mpi4py import MPI
+
+    comm = MPI.COMM_WORLD
+
+    if comm.rank == 0:
+        send_rank = 1
+        recv_rank = 1
+
+        recv = pt.make_distributed_recv(
+                    src_rank=recv_rank, comm_tag=42,
+                    shape=(10,), dtype=np.float64)
+        y = 2*recv
+
+        send = pt.staple_distributed_send(
+            y, dest_rank=send_rank, comm_tag=43,
+            stapled_to=pt.ones(10))
+        out = pt.make_dict_of_named_arrays({"out": send})
+    elif comm.rank == 1:
+        send_rank = 0
+        recv_rank = 0
+        x = pt.make_data_wrapper(np.ones(10))
+
+        send = pt.staple_distributed_send(
+            2*x, dest_rank=send_rank, comm_tag=42,
+            stapled_to=pt.zeros(10))
+        recv = pt.make_distributed_recv(
+                    src_rank=recv_rank, comm_tag=43,
+                    shape=(10,), dtype=np.float64)
+        out = pt.make_dict_of_named_arrays({"out1": send, "out2": recv})
+    else:
+        raise AssertionError()
+
+    distributed_parts = pt.find_distributed_partition(comm, out)
+
+    pt.verify_distributed_partition(comm, distributed_parts)
+    prg_per_partition = pt.generate_code_for_partition(distributed_parts)
+
+    # Execute the distributed partition
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    pt.execute_distributed_partition(distributed_parts, prg_per_partition,
+                                            queue, comm,
+                                            input_args={})
+
+# }}}
+
+
 # {{{ test verify_distributed_partition
 
 def test_verify_distributed_partition():