# NOTE(review): re-laid-out from a unified diff of test/test_distributed.py
# (index 81468f21ffa8df35f7568c9fe306b614e80bb3be
#     ..15452294d416209ca5e640928693572543821d25, mode 100644)
# whose line breaks had been collapsed.

# -- original hunk "@@ -26,6 +26,7 @@": import context; the patch adds
# -- the ``pyopencl.array`` import.
from pytools.graph import CycleError
from pyopencl.tools import (  # noqa
        pytest_generate_tests_for_pyopencl as pytest_generate_tests)
import pyopencl as cl
import pyopencl.array as cla
import numpy as np
import pytato as pt
import sys

# -- original hunk "@@ -338,6 +339,59 @@", located right after
# -- test_dag_with_recv_as_output(); the next line closes that test's fold.
# }}}


# {{{ test DAG with multiple send nodes per sent array

def _test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory):
    """Per-rank body of the "one array, several sends" test.

    Rank 0 wraps a random ``(10, 4)`` host array (moved through
    ``cla.to_device``), doubles it, and staples two sends of that doubled
    array, one to rank 1 and one to rank 2, both with comm tag 42. Every
    other rank receives a ``(10, 4)`` float64 array from rank 0 under the
    same tag. Each rank's output ``z`` is four times the sent/received
    array and is compared against ``8 * x``, where ``x`` is the host array
    broadcast from rank 0.

    :arg ctx_factory: called with no arguments to obtain the OpenCL context
        that the command queue is created on.
    """
    from numpy.random import default_rng
    from mpi4py import MPI  # pylint: disable=import-error

    comm = MPI.COMM_WORLD
    queue = cl.CommandQueue(ctx_factory())
    on_root = comm.rank == 0

    # {{{ construct the DAG

    if on_root:
        host_x = default_rng().random((10, 4))
        doubled = 2 * pt.make_data_wrapper(cla.to_device(queue, host_x))

        # Two distinct send nodes that share the same payload and comm tag
        # and differ only in their destination rank.
        send_to_1 = pt.staple_distributed_send(
            doubled, dest_rank=1, comm_tag=42,
            stapled_to=pt.ones(10))
        send_to_2 = pt.staple_distributed_send(
            doubled, dest_rank=2, comm_tag=42,
            stapled_to=pt.ones(10))

        named_outputs = {
            "z": 4 * doubled,
            "send1": send_to_1,
            "send2": send_to_2,
        }
    else:
        # Placeholder only; replaced by rank 0's data in the broadcast below.
        host_x = None
        incoming = pt.make_distributed_recv(
            src_rank=0, comm_tag=42,
            shape=(10, 4), dtype=np.float64)
        named_outputs = {"z": 4 * incoming}

    dag = pt.make_dict_of_named_arrays(named_outputs)

    # }}}

    parts = pt.find_distributed_partition(comm, dag)
    pt.verify_distributed_partition(comm, parts)
    out_dict = pt.execute_distributed_partition(
        parts, pt.generate_code_for_partition(parts), queue, comm)

    # Collective call: rank 0 contributes its host array, all other ranks
    # pass None and get rank 0's array back.
    host_x = comm.bcast(host_x)

    np.testing.assert_allclose(out_dict["z"].get(), 8 * host_x)


def test_dag_with_multiple_send_nodes_per_sent_array():
    """Run the per-rank body above through ``run_test_with_mpi`` with 3 ranks."""
    run_test_with_mpi(3, _test_dag_with_multiple_send_nodes_per_sent_array_inner)

#
}}} + + # {{{ test deterministic partitioning def _gather_random_dist_partitions(ctx_factory):