diff --git a/test/test_distributed.py b/test/test_distributed.py index 013f3659e4ff593e641d26ef6c46df3aad87f8fd..8d7cd50dd5d5e5f6e1c13b2f01f7a3d659335990 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -25,6 +25,7 @@ import pytest from pytools.graph import CycleError from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) +import pytest # noqa import pyopencl as cl import pyopencl.array as cla import numpy as np @@ -36,7 +37,6 @@ import os # {{{ mpi test infrastructure def run_test_with_mpi(num_ranks, f, *args, extra_env_vars=None): - import pytest pytest.importorskip("mpi4py") if extra_env_vars is None: @@ -460,6 +460,67 @@ def test_dag_with_multiple_send_nodes_per_sent_array(): # }}} +# {{{ test DAG with periodic communication + +def _test_dag_with_periodic_communication_inner(ctx_factory): + from numpy.random import default_rng + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + # {{{ construct the DAG + + rank = comm.Get_rank() + size = comm.Get_size() + rng = default_rng() + x_np = rng.random((10, 4)) + x = pt.make_data_wrapper(cla.to_device(queue, x_np)) + + x_plus = pt.staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=42, + stapled_to=pt.make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=42, shape=(10, 4), + dtype=np.float64)) + + y = x + x_plus + + y_plus = pt.staple_distributed_send(y, dest_rank=(rank-1) % size, comm_tag=43, + stapled_to=pt.make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=43, shape=(10, 4), + dtype=np.float64)) + + z = y + y_plus + + dag = pt.make_dict_of_named_arrays({"z": z}) + + # }}} + + parts = pt.find_distributed_partition(comm, dag) + pt.verify_distributed_partition(comm, parts) + prg_per_partition = pt.generate_code_for_partition(parts) + out_dict = pt.execute_distributed_partition( + parts, prg_per_partition, queue, comm) + + comm.isend(x_np, dest=(rank-1) % size, tag=44) + ref_x_plus_np = comm.recv(source=(rank+1) % size, tag=44) + + ref_y_np = x_np + ref_x_plus_np + + comm.isend(ref_y_np, dest=(rank-1) % size, tag=45) + ref_y_plus_np = comm.recv(source=(rank+1) % size, tag=45) + + ref_res = ref_y_np + ref_y_plus_np + + np.testing.assert_allclose(out_dict["z"].get(), ref_res) + + +@pytest.mark.parametrize("num_ranks", [2, 3]) +def test_dag_with_periodic_communication(num_ranks): + run_test_with_mpi(num_ranks, _test_dag_with_periodic_communication_inner) + +# }}} + + # {{{ test deterministic partitioning def _gather_random_dist_partitions(ctx_factory): @@ -573,7 +634,6 @@ def test_verify_distributed_partition(): def _do_verify_distributed_partition(ctx_factory): from mpi4py import MPI # pylint: disable=import-error comm = MPI.COMM_WORLD - import pytest from pytato.distributed.verify import (DuplicateSendError, DuplicateRecvError, MissingSendError, MissingRecvError)