From babbbb625cb06b7792e56e20f7f18a007aa060d3 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 18 Jan 2023 13:59:07 -0600 Subject: [PATCH] Force distributed example to run multi-rank Enabled by https://github.com/inducer/ci-support/commit/f941e157312856fdd8f6ece65eac9347c1932b9c Co-authored-by: Andreas Kloeckner --- .github/workflows/ci.yml | 2 -- .../{distributed.py => mpi-distributed.py} | 31 ++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) rename examples/{distributed.py => mpi-distributed.py} (79%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 707ad83..9405eaa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,8 +78,6 @@ jobs: build_py_project_in_conda_env run_examples - mpirun -n 2 --oversubscribe python distributed.py - docs: name: Documentation runs-on: ubuntu-latest diff --git a/examples/distributed.py b/examples/mpi-distributed.py similarity index 79% rename from examples/distributed.py rename to examples/mpi-distributed.py index 3301f96..dd29a82 100644 --- a/examples/distributed.py +++ b/examples/mpi-distributed.py @@ -5,6 +5,7 @@ comm = MPI.COMM_WORLD import pytato as pt import pyopencl as cl +import pyopencl.array as cl_array import numpy as np from pytato import (find_distributed_partition, generate_code_for_partition, @@ -14,24 +15,36 @@ from pytato import (find_distributed_partition, generate_code_for_partition, def main(): + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + rank = comm.Get_rank() size = comm.Get_size() rng = np.random.default_rng() x_in = rng.integers(100, size=(4, 4)) - x = pt.make_data_wrapper(x_in) + x_in_dev = cl_array.to_device(queue, x_in) + x = pt.make_data_wrapper(x_in_dev) + + if size < 2: + raise RuntimeError("it doesn't make sense to run the " + "distributed-memory test single-rank" + # and self-sends aren't supported for now + ) mytag_x = (main, "x") - x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag_x, - stapled_to=make_distributed_recv( - src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4), dtype=int)) + x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size, + comm_tag=mytag_x, stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4), + dtype=int)) y = x+x_plus mytag_y = (main, "y") - y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size, comm_tag=mytag_y, - stapled_to=make_distributed_recv( - src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4), dtype=int)) + y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size, + comm_tag=mytag_y, stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4), + dtype=int)) z = y+y_plus @@ -52,10 +65,6 @@ def main(): from pytato.visualization import get_dot_graph_from_partition get_dot_graph_from_partition(distributed_parts) - # Execute the distributed partition - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - pt.verify_distributed_partition(comm, distributed_parts) context = execute_distributed_partition(distributed_parts, prg_per_partition, -- GitLab