diff --git a/test/test_distributed.py b/test/test_distributed.py index 0e67911229952ed3467b8a0279088c6704ca3764..9a008c0e1582ec3578f02d49f9a6658530ddaa00 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -209,6 +209,42 @@ def _do_test_distributed_execution_random_dag(ctx_factory): # }}} +# {{{ test DAG with no comm nodes + +def _test_dag_with_no_comm_nodes_inner(ctx_factory): + from numpy.random import default_rng + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + rng = default_rng() + x_np = rng.random((10, 4)) + x = pt.make_data_wrapper(x_np) + + # {{{ construct the DAG + + out1 = 2 * x + out2 = 4 * out1 + dag = pt.make_dict_of_named_arrays({"out1": out1, "out2": out2}) + + # }}} + + parts = find_distributed_partition(dag) + assert len(parts.parts) == 1 + prg_per_partition = generate_code_for_partition(parts) + out_dict = execute_distributed_partition(parts, prg_per_partition, queue, comm) + + np.testing.assert_allclose(out_dict["out1"], 2 * x_np) + np.testing.assert_allclose(out_dict["out2"], 8 * x_np) + + +def test_dag_with_no_comm_nodes(): + run_test_with_mpi(2, _test_dag_with_no_comm_nodes_inner) + +# }}} + + if __name__ == "__main__": if "RUN_WITHIN_MPI" in os.environ: run_test_with_mpi_inner()