From 1ac6c876150331ab2376635d5377eeb59e4512ad Mon Sep 17 00:00:00 2001
From: ellis <eshoag2@illinois.edu>
Date: Sat, 25 Mar 2017 16:27:55 -0500
Subject: [PATCH] interpart_adj_groups is not a list of maps from partition
 numbers to InterPartitionAdj

---
 meshmode/mesh/__init__.py   | 42 +++++++++++++------------------------
 meshmode/mesh/processing.py | 19 +++++++++--------
 test/test_meshmode.py       | 12 +++++------
 3 files changed, 29 insertions(+), 44 deletions(-)

diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py
index 472e6c4f..7ebaf393 100644
--- a/meshmode/mesh/__init__.py
+++ b/meshmode/mesh/__init__.py
@@ -428,9 +428,6 @@ class InterPartitionAdj():
 
         ``element_faces[i]`` is the face of ``elements[i]`` that has a neighbor.
 
-    .. attribute:: part_indices
-        ``part_indices[i]`` gives the partition index of the neighboring face.
-
     .. attribute:: neighbors
 
         ``neighbors[i]`` gives the element number within the neighboring partiton
@@ -444,7 +441,6 @@ class InterPartitionAdj():
         ``neighbor_faces[i]`` gives face index within the neighboring partition
         of the face connected to ``elements[i]``
 
-    .. automethod:: add_connection
     .. automethod:: get_neighbor
 
     .. versionadded:: 2017.1
@@ -455,40 +451,21 @@ class InterPartitionAdj():
         self.element_faces = []
         self.neighbors = []
         self.neighbor_faces = []
-        self.part_indices = []
-
-    def add_connection(self, elem, face, part_idx, neighbor_elem, neighbor_face):
-        """
-        Adds a connection from ``elem`` and ``face`` within :class:`Mesh` to
-        ``neighbor_elem`` and ``neighbor_face`` of the neighboring partion
-        of type :class:`Mesh` given by `part_idx`.
-        :arg elem:
-        :arg face:
-        :arg part_idx:
-        :arg neighbor_elem:
-        :arg neighbor_face:
-        """
-        self.elements.append(elem)
-        self.element_faces.append(face)
-        self.part_indices.append(part_idx)
-        self.neighbors.append(neighbor_elem)
-        self.neighbor_faces.append(neighbor_face)
 
     def get_neighbor(self, elem, face):
         """
         :arg elem
         :arg face
-        :returns: A tuple ``(part_idx, neighbor_elem, neighbor_face)`` of
+        :returns: A tuple ``(neighbor_elem, neighbor_face)`` of
                     neighboring elements within another :class:`Mesh`.
-                    Or (-1, -1, -1) if the face does not have a neighbor.
+                    Or (-1, -1) if the face does not have a neighbor.
         """
         for idx in range(len(self.elements)):
             if elem == self.elements[idx] and face == self.element_faces[idx]:
-                return (self.part_indices[idx],
-                        self.neighbors[idx],
+                return (self.neighbors[idx],
                         self.neighbor_faces[idx])
         #raise RuntimeError("This face does not have a neighbor")
-        return (-1, -1, -1)
+        return (-1, -1)
 
 # }}}
 
@@ -620,6 +597,15 @@ class Mesh(Record):
         (Note that element groups are not necessarily contiguous like the figure
         may suggest.)
 
+    .. attribute:: interpart_adj_groups
+
+        A list of mappings from neighbor partition numbers to instances of
+        :class:`InterPartitionAdj`.
+
+        ``interpart_adj_gorups[igrp][ineighbor_part]`` gives
+        the set of facial adjacency relations between group *igrp*
+        and partition *ineighbor_part*.
+
     .. attribute:: boundary_tags
 
         A tuple of boundary tag identifiers. :class:`BTAG_ALL` and
@@ -872,7 +858,7 @@ class Mesh(Record):
             if elem < grp.nelements:
                 return igrp
             elem -= grp.nelements
-        raise RuntimeError("Could not find group with element %d" % elem)
+        raise RuntimeError("Could not find group with element %d." % elem)
 
     # Design experience: Try not to add too many global data structures to the
     # mesh. Let the element groups be responsible for that at the mesh level.
diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index c028c63f..27a6c56d 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -148,7 +148,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
         facial_adjacency_groups=None, boundary_tags=boundary_tags)
 
     from meshmode.mesh import InterPartitionAdj
-    adj_grps = [InterPartitionAdj() for _ in range(len(part_mesh.groups))]
+    adj_grps = [{} for _ in range(len(part_mesh.groups))]
 
     for igrp, grp in enumerate(part_mesh.groups):
         elem_base = grp.element_nr_base
@@ -173,7 +173,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
                                parent_facial_group.element_faces[idx] == face:
                         rank_neighbor = (parent_facial_group.neighbors[idx]
                                             + parent_elem_base)
-                        rank_neighbor_face = parent_facial_group.neighbor_faces[idx]
+                        n_face = parent_facial_group.neighbor_faces[idx]
 
                         n_part_num = part_per_element[rank_neighbor]
                         tags = tags & ~part_mesh.boundary_tag_bit(BTAG_ALL)
@@ -185,14 +185,15 @@ def partition_mesh(mesh, part_per_element, part_nr):
                         n_elem = np.count_nonzero(
                                     part_per_element[:rank_neighbor] == n_part_num)
 
+                        if n_part_num not in adj_grps[igrp]:
+                            adj_grps[igrp][n_part_num] = InterPartitionAdj()
+
                         # I cannot compute the group because the other
-                        # partitions have not been built yet.
-                        adj_grps[igrp].add_connection(
-                                            elem,
-                                            face,
-                                            n_part_num,
-                                            n_elem,
-                                            rank_neighbor_face)
+                        # partitions may not have been built yet.
+                        adj_grps[igrp][n_part_num].elements.append(elem)
+                        adj_grps[igrp][n_part_num].element_faces.append(face)
+                        adj_grps[igrp][n_part_num].neighbors.append(n_elem)
+                        adj_grps[igrp][n_part_num].neighbor_faces.append(n_face)
 
     connected_mesh = part_mesh.copy()
     connected_mesh.interpart_adj_groups = adj_grps
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 51fad45d..e2980d2d 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -145,27 +145,25 @@ def test_partition_mesh():
     for part_num in range(num_parts):
         part, part_to_global = new_meshes[part_num]
         for grp_num, f_groups in enumerate(part.facial_adjacency_groups):
-            adj = part.interpart_adj_groups[grp_num]
             f_grp = f_groups[None]
             elem_base = part.groups[grp_num].element_nr_base
             for idx, elem in enumerate(f_grp.elements):
                 tag = -f_grp.neighbors[idx]
                 assert tag >= 0
                 face = f_grp.element_faces[idx]
-                for n_part_num in range(num_parts):
+                for n_part_num, adj in part.interpart_adj_groups[grp_num].items():
                     n_part, n_part_to_global = new_meshes[n_part_num]
                     if tag & part.boundary_tag_bit(BTAG_PARTITION(n_part_num)) != 0:
                         num_tags[n_part_num] += 1
 
-                        (i, n_elem, n_face) = adj.get_neighbor(elem, face)
-                        assert i == n_part_num
+                        (n_elem, n_face) = adj.get_neighbor(elem, face)
                         n_grp_num = n_part.find_igrp(n_elem)
-                        n_adj = n_part.interpart_adj_groups[n_grp_num]
+                        n_adj = n_part.interpart_adj_groups[n_grp_num][part_num]
                         n_elem_base = n_part.groups[n_grp_num].element_nr_base
                         n_elem -= n_elem_base
-                        assert (part_num, elem + elem_base, face) ==\
+                        assert (elem + elem_base, face) ==\
                                             n_adj.get_neighbor(n_elem, n_face),\
-                                            "InterpartitionAdj is not consistent"
+                                            "InterPartitionAdj is not consistent"
 
                         n_part_to_global = new_meshes[n_part_num][1]
                         p_elem = part_to_global[elem + elem_base]
-- 
GitLab