From ae93d6448460dc92235bea3a7b97e7d2a0fbc780 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 31 Jan 2023 15:27:48 -0600 Subject: [PATCH] add test for materialized arrays promoted to part outputs --- test/test_distributed.py | 68 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/test/test_distributed.py b/test/test_distributed.py index 1545229..013f365 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -339,6 +339,74 @@ def test_dag_with_recv_as_output(): # }}} +# {{{ test DAG with a materialized array promoted to a part output + +def _test_dag_with_materialized_array_promoted_to_part_output_inner(ctx_factory): + from numpy.random import default_rng + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + # {{{ construct the DAG + + if comm.rank == 0: + rng = default_rng() + x_np = rng.random((10, 4)) + x = pt.make_data_wrapper(cla.to_device(queue, x_np)) + y = 2 * x + # Force y to be materialized + from pytato.tags import ImplStored + y = y.tagged(ImplStored()) + z = 2 * y + send = pt.staple_distributed_send( + z, dest_rank=1, comm_tag=42, + stapled_to=pt.ones(10)) + w = pt.make_distributed_recv( + src_rank=1, comm_tag=43, + shape=(10, 4), dtype=np.float64) + q = y + w + dag = pt.make_dict_of_named_arrays({"q": q, "send": send}) + else: + z = pt.make_distributed_recv( + src_rank=0, comm_tag=42, + shape=(10, 4), dtype=np.float64) + w = 2 * z + send = pt.staple_distributed_send( + w, dest_rank=0, comm_tag=43, + stapled_to=pt.ones(10)) + q = z/2 + w + dag = pt.make_dict_of_named_arrays({"q": q, "send": send}) + + # }}} + + parts = pt.find_distributed_partition(comm, dag) + prg_per_partition = pt.generate_code_for_partition(parts) + out_dict = pt.execute_distributed_partition( + parts, prg_per_partition, queue, comm) + + if comm.rank == 0: + # Is this too fragile? + # part 0 should return a materialized array and a send + assert len(parts.parts[0].output_names) == 2 + # part 1 should take a materialized array and a recv + assert len(parts.parts[1].partition_input_names) == 2 + + if comm.rank == 0: + comm.bcast(x_np) + else: + x_np = comm.bcast(None) + + np.testing.assert_allclose(out_dict["q"].get(), 10 * x_np) + + +def test_dag_with_materialized_array_promoted_to_part_output(): + run_test_with_mpi( + 2, _test_dag_with_materialized_array_promoted_to_part_output_inner) + +# }}} + + # {{{ test DAG with multiple send nodes per sent array def _test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory): -- GitLab