diff --git a/test/test_distributed.py b/test/test_distributed.py index 3ac44ad4c158cf64d761e6dd7dcb5fc1027c3920..0e67911229952ed3467b8a0279088c6704ca3764 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -166,10 +166,10 @@ def _do_test_distributed_execution_random_dag(ctx_factory): additional_generators=[ (comm_fake_prob, gen_comm) ]) - x_comm = make_random_dag(rdagc_comm) + pt_dag = pt.DictOfNamedArrays({"result": make_random_dag(rdagc_comm)}) + x_comm = pt.transform.materialize_with_mpms(pt_dag) - distributed_partition = find_distributed_partition( - pt.DictOfNamedArrays({"result": x_comm})) + distributed_partition = find_distributed_partition(x_comm) # Transform symbolic tags into numeric ones for MPI distributed_partition, _new_mpi_base_tag = number_distributed_tags(