From 691d38ada09b89bb02e379f23f59b02e98bf0370 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Tue, 7 Nov 2023 10:16:03 -0600
Subject: [PATCH] make MPI tags deterministic (gh-462)

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
---
 pytato/distributed/__init__.py |  3 ++-
 pytato/distributed/tags.py     | 15 +++++++++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/pytato/distributed/__init__.py b/pytato/distributed/__init__.py
index 04b368a..4354b2f 100644
--- a/pytato/distributed/__init__.py
+++ b/pytato/distributed/__init__.py
@@ -22,7 +22,8 @@ Internal stuff that is only here because the documentation tool wants it
 
 .. class:: CommTagType
 
-    A type representing a communication tag.
+    A type representing a communication tag. Communication tags must be
+    hashable and totally ordered (and hence comparable).
 
 .. class:: ShapeType
 
diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py
index 9e3bde8..41ae327 100644
--- a/pytato/distributed/tags.py
+++ b/pytato/distributed/tags.py
@@ -31,7 +31,7 @@ THE SOFTWARE.
 """
 
 
-from typing import TYPE_CHECKING, Tuple, FrozenSet, Any
+from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar
 
 from pytato.distributed.partition import DistributedGraphPartition
 
@@ -40,6 +40,9 @@ if TYPE_CHECKING:
     import mpi4py.MPI
 
 
+T = TypeVar("T")
+
+
 # {{{ construct tag numbering
 
 def number_distributed_tags(
@@ -59,6 +62,10 @@ def number_distributed_tags(
 
         This is a potentially heavyweight MPI-collective operation on
         *mpi_communicator*.
+
+    .. note::
+
+        This function requires that symbolic tags are comparable.
     """
     tags = frozenset({
             recv.comm_tag
@@ -73,8 +80,8 @@ def number_distributed_tags(
     from mpi4py import MPI
 
     def set_union(
-            set_a: FrozenSet[Any], set_b: FrozenSet[Any],
-            mpi_data_type: MPI.Datatype) -> FrozenSet[str]:
+            set_a: FrozenSet[T], set_b: FrozenSet[T],
+            mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]:
         assert mpi_data_type is None
         assert isinstance(set_a, frozenset)
         assert isinstance(set_b, frozenset)
@@ -99,7 +106,7 @@ def number_distributed_tags(
         next_tag = base_tag
         assert isinstance(all_tags, frozenset)
 
-        for sym_tag in all_tags:
+        for sym_tag in sorted(all_tags):
             sym_tag_to_int_tag[sym_tag] = next_tag
             next_tag += 1
 
-- 
GitLab