diff --git a/test/test_distributed.py b/test/test_distributed.py index 97a7be549a3598fff89374a484d15e7d341ad69f..802b7e74afeb0c253a0a83cd2e563e54006276d2 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -303,6 +303,62 @@ def test_deterministic_partitioning(seed): # }}} +# {{{ test Kaushik's MWE + +def test_kaushik_mwe(): + run_test_with_mpi(2, _do_test_kaushik_mwe) + + +def _do_test_kaushik_mwe(ctx_factory): + # from https://github.com/inducer/pytato/pull/393#issuecomment-1324642248 + from mpi4py import MPI + + comm = MPI.COMM_WORLD + + if comm.rank == 0: + send_rank = 1 + recv_rank = 1 + + recv = pt.make_distributed_recv( + src_rank=recv_rank, comm_tag=42, + shape=(10,), dtype=np.float64) + y = 2*recv + + send = pt.staple_distributed_send( + y, dest_rank=send_rank, comm_tag=43, + stapled_to=pt.ones(10)) + out = pt.make_dict_of_named_arrays({"out": send}) + elif comm.rank == 1: + send_rank = 0 + recv_rank = 0 + x = pt.make_data_wrapper(np.ones(10)) + + send = pt.staple_distributed_send( + 2*x, dest_rank=send_rank, comm_tag=42, + stapled_to=pt.zeros(10)) + recv = pt.make_distributed_recv( + src_rank=recv_rank, comm_tag=43, + shape=(10,), dtype=np.float64) + out = pt.make_dict_of_named_arrays({"out1": send, "out2": recv}) + else: + raise AssertionError() + + distributed_parts = pt.find_distributed_partition(comm, out) + + pt.verify_distributed_partition(comm, distributed_parts) + prg_per_partition = pt.generate_code_for_partition(distributed_parts) + + # Execute the distributed partition + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + pt.execute_distributed_partition(distributed_parts, prg_per_partition, + queue, comm, + input_args={}) + +# }}} + + # {{{ test verify_distributed_partition def test_verify_distributed_partition():