From 94d1efbfc84a24ef0bf7f6df5ede5c71aca4868a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 18 Jan 2023 10:11:01 -0600 Subject: [PATCH] add test for dag with duplicated output array --- test/test_distributed.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/test/test_distributed.py b/test/test_distributed.py index 802b7e7..34c1b46 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -255,6 +255,42 @@ def test_dag_with_no_comm_nodes(): # }}} +# {{{ test DAG with duplicated output arrays + +def _test_dag_with_duplicated_output_arrays_inner(ctx_factory): + from numpy.random import default_rng + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + rng = default_rng() + x_np = rng.random((10, 4)) + x = pt.make_data_wrapper(x_np) + + # {{{ construct the DAG + + y = 2 * x + out = 4 * y + dag = pt.make_dict_of_named_arrays({"out1": out, "out2": out}) + + # }}} + + parts = pt.find_distributed_partition(comm, dag) + prg_per_partition = pt.generate_code_for_partition(parts) + out_dict = pt.execute_distributed_partition( + parts, prg_per_partition, queue, comm) + + np.testing.assert_allclose(out_dict["out1"], 8 * x_np) + np.testing.assert_allclose(out_dict["out2"], 8 * x_np) + + +def test_dag_with_duplicated_output_arrays(): + run_test_with_mpi(2, _test_dag_with_duplicated_output_arrays_inner) + +# }}} + + # {{{ test deterministic partitioning def _gather_random_dist_partitions(ctx_factory): -- GitLab