From 96282ec96b4fe3297a5a80faacf9bcea85a7dd45 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 20 Mar 2022 09:44:22 -0500 Subject: [PATCH] Adds example to demonstrate eager send scheduling --- examples/demo_distributed_node_duplication.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 examples/demo_distributed_node_duplication.py diff --git a/examples/demo_distributed_node_duplication.py b/examples/demo_distributed_node_duplication.py new file mode 100644 index 0000000..39307cc --- /dev/null +++ b/examples/demo_distributed_node_duplication.py @@ -0,0 +1,37 @@ +""" +An example to demonstrate the behavior of +:func:`pytato.find_distrbuted_partition`. One of the key characteristic of the +partitioning routine is to recompute expressions that appear in the multiple +partitions but are not materialized. +""" +import pytato as pt +import numpy as np + +size = 2 +rank = 0 + +x1 = pt.make_placeholder("x1", shape=(10, 4), dtype=np.float64) +x2 = pt.make_placeholder("x2", shape=(10, 4), dtype=np.float64) +x3 = pt.make_placeholder("x3", shape=(10, 4), dtype=np.float64) +x4 = pt.make_placeholder("x4", shape=(10, 4), dtype=np.float64) + + +tmp1 = (x1 + x2).tagged(pt.tags.ImplStored()) +tmp2 = tmp1 + x3 +# "marking" *tmp2* so that its duplication can be clearly visualized. +tmp2 = tmp2.tagged(pt.tags.Named("tmp2")) +tmp3 = (2 * x4).tagged(pt.tags.ImplStored()) +tmp4 = tmp2 + tmp3 + +recv = pt.staple_distributed_send(tmp4, dest_rank=(rank-1) % size, comm_tag=10, + stapled_to=pt.make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=10, shape=(10, 4), dtype=int)) + +out = tmp2 + recv +result = pt.make_dict_of_named_arrays({"out": out}) + +partitions = pt.find_distributed_partition(result) + +# Visualize *partitions* to see that each of the two partitions contains a node +# named 'tmp2'. +pt.show_dot_graph(partitions) -- GitLab