diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 707ad83157b3ecaad52b1fe4a24723032cbbc1b1..9405eaafdeb68ffd3053328c46a6b87cea567921 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -78,8 +78,6 @@ jobs:
                 build_py_project_in_conda_env
                 run_examples
 
-                mpirun -n 2 --oversubscribe python distributed.py
-
     docs:
         name: Documentation
         runs-on: ubuntu-latest
diff --git a/examples/distributed.py b/examples/mpi-distributed.py
similarity index 79%
rename from examples/distributed.py
rename to examples/mpi-distributed.py
index 3301f96983a442c2d29852587a6622367e0cab73..dd29a82ab2d0fb6fe4feeb9eee44dde298b13ffa 100644
--- a/examples/distributed.py
+++ b/examples/mpi-distributed.py
@@ -5,6 +5,7 @@ comm = MPI.COMM_WORLD
 
 import pytato as pt
 import pyopencl as cl
+import pyopencl.array as cl_array
 import numpy as np
 
 from pytato import (find_distributed_partition, generate_code_for_partition,
@@ -14,24 +15,36 @@ from pytato import (find_distributed_partition, generate_code_for_partition,
 
 
 def main():
+    ctx = cl.create_some_context()
+    queue = cl.CommandQueue(ctx)
+
     rank = comm.Get_rank()
     size = comm.Get_size()
     rng = np.random.default_rng()
 
     x_in = rng.integers(100, size=(4, 4))
-    x = pt.make_data_wrapper(x_in)
+    x_in_dev = cl_array.to_device(queue, x_in)
+    x = pt.make_data_wrapper(x_in_dev)
+
+    if size < 2:
+        raise RuntimeError("it doesn't make sense to run the "
+                           "distributed-memory test single-rank"
+                           # and self-sends aren't supported for now
+                           )
 
     mytag_x = (main, "x")
-    x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag_x,
-            stapled_to=make_distributed_recv(
-                src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4), dtype=int))
+    x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size,
+            comm_tag=mytag_x, stapled_to=make_distributed_recv(
+                src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4),
+                dtype=int))
 
     y = x+x_plus
 
     mytag_y = (main, "y")
-    y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size, comm_tag=mytag_y,
-            stapled_to=make_distributed_recv(
-                src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4), dtype=int))
+    y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size,
+            comm_tag=mytag_y, stapled_to=make_distributed_recv(
+                src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4),
+                dtype=int))
 
     z = y+y_plus
 
@@ -52,10 +65,6 @@ def main():
         from pytato.visualization import get_dot_graph_from_partition
         get_dot_graph_from_partition(distributed_parts)
 
-    # Execute the distributed partition
-    ctx = cl.create_some_context()
-    queue = cl.CommandQueue(ctx)
-
     pt.verify_distributed_partition(comm, distributed_parts)
 
     context = execute_distributed_partition(distributed_parts, prg_per_partition,