diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 707ad83157b3ecaad52b1fe4a24723032cbbc1b1..9405eaafdeb68ffd3053328c46a6b87cea567921 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,8 +78,6 @@ jobs: build_py_project_in_conda_env run_examples - mpirun -n 2 --oversubscribe python distributed.py - docs: name: Documentation runs-on: ubuntu-latest diff --git a/examples/distributed.py b/examples/mpi-distributed.py similarity index 79% rename from examples/distributed.py rename to examples/mpi-distributed.py index 3301f96983a442c2d29852587a6622367e0cab73..dd29a82ab2d0fb6fe4feeb9eee44dde298b13ffa 100644 --- a/examples/distributed.py +++ b/examples/mpi-distributed.py @@ -5,6 +5,7 @@ comm = MPI.COMM_WORLD import pytato as pt import pyopencl as cl +import pyopencl.array as cl_array import numpy as np from pytato import (find_distributed_partition, generate_code_for_partition, @@ -14,24 +15,36 @@ from pytato import (find_distributed_partition, generate_code_for_partition, def main(): + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + rank = comm.Get_rank() size = comm.Get_size() rng = np.random.default_rng() x_in = rng.integers(100, size=(4, 4)) - x = pt.make_data_wrapper(x_in) + x_in_dev = cl_array.to_device(queue, x_in) + x = pt.make_data_wrapper(x_in_dev) + + if size < 2: + raise RuntimeError("it doesn't make sense to run the " + "distributed-memory test single-rank" + # and self-sends aren't supported for now + ) mytag_x = (main, "x") - x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag_x, - stapled_to=make_distributed_recv( - src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4), dtype=int)) + x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size, + comm_tag=mytag_x, stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4), + dtype=int)) y = x+x_plus mytag_y = (main, "y") - y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size, comm_tag=mytag_y, - stapled_to=make_distributed_recv( - src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4), dtype=int)) + y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size, + comm_tag=mytag_y, stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4), + dtype=int)) z = y+y_plus @@ -52,10 +65,6 @@ def main(): from pytato.visualization import get_dot_graph_from_partition get_dot_graph_from_partition(distributed_parts) - # Execute the distributed partition - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - pt.verify_distributed_partition(comm, distributed_parts) context = execute_distributed_partition(distributed_parts, prg_per_partition,