From 5c1eb19a2ec9f8c1328cd23b30e4716e481cf64d Mon Sep 17 00:00:00 2001
From: ellis <eshoag2@illinois.edu>
Date: Thu, 9 Mar 2017 00:15:10 -0600
Subject: [PATCH] Add documentation and fix bugs.

---
 meshmode/mesh/__init__.py   | 91 +++++++++++++++++++++++++++++++++++--
 meshmode/mesh/processing.py | 28 +++++++-----
 test/test_meshmode.py       | 65 ++++++++++++++++++--------
 3 files changed, 148 insertions(+), 36 deletions(-)

diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py
index 3673158a..0a01529d 100644
--- a/meshmode/mesh/__init__.py
+++ b/meshmode/mesh/__init__.py
@@ -88,6 +88,35 @@ class BTAG_NO_BOUNDARY(object):  # noqa
     pass
 
 
+class BTAG_PARTITION(object):
+    """
+    A boundary tag indicating that this edge is adjacent to an element of
+    another :class:`Mesh`. The partition number of the adjacent mesh
+    is given by ``part_nr``.
+    
+    .. attribute:: part_nr
+
+    .. versionadded:: 2017.1
+    """
+    def __init__(self, part_nr):
+        self.part_nr = int(part_nr)
+
+    # TODO is this acceptable?
+    # __eq__ is also defined so maybe the hash value isn't too important
+    # for dictionaries.
+    def __hash__(self):
+        return self.part_nr
+
+    def __eq__(self, other):
+        if isinstance(other, BTAG_PARTITION):
+            return self.part_nr == other.part_nr
+        else:
+            return False
+
+    def __nq__(self, other):
+        return not self.__eq__(other)
+
+
 SYSTEM_TAGS = set([BTAG_NONE, BTAG_ALL, BTAG_REALLY_ALL, BTAG_NO_BOUNDARY])
 
 # }}}
@@ -386,17 +415,66 @@ class NodalAdjacency(Record):
 
 class InterPartitionAdj():
     """
-    Interface is not final.
+    Describes facial adjacency information of elements in one :class:`Mesh` to
+    elements in another :class:`Mesh`. The element's boundary tag gives the 
+    partition that it is connected to.
+
+    .. attribute:: elements
+
+        `:class:Mesh`-local element numbers that have neighbors.
+    
+    .. attribute:: element_faces
+
+        ``element_faces[i]`` is the face of ``elements[i]`` that has a neighbor.
+
+    .. attribute:: neighbors
+
+        ``neighbors[i]`` gives the element number within the neighboring partiton 
+        of the element connected to ``elements[i]``.
+
+    .. attribute:: neighbor_faces
+
+        ``neighbor_faces[i]`` gives face index within the neighboring partition
+        of the face connected to ``elements[i]``
+    
+    .. automethod:: add_connection
+    .. automethod:: get_neighbor
+
+    .. versionadded:: 2017.1
     """
 
     def __init__(self):
-        self.adjacent = dict()
+        self.elements = []
+        self.element_faces = []
+        self.neighbors = []
+        self.neighbor_faces = []
 
     def add_connection(self, elem, face, neighbor_elem, neighbor_face):
-        self.adjacent[(elem, face)] = (neighbor_elem, neighbor_face)
+        """
+        Adds a connection from ``elem`` and ``face`` within :class:`Mesh` to 
+        ``neighbor_elem`` and ``neighbor_face`` of another neighboring partion
+        of type :class:`Mesh`.
+        :arg elem
+        :arg face
+        :arg neighbor_elem
+        :arg neighbor_face
+        """
+        self.elements.append(elem)
+        self.element_faces.append(face)
+        self.neighbors.append(neighbor_elem)
+        self.neighbor_faces.append(neighbor_face)
 
     def get_neighbor(self, elem, face):
-        return self.adjacent[(elem, face)]
+        """
+        :arg elem
+        :arg face
+        :returns: A tuple ``(neighbor_elem, neighbor_face)`` of neighboring
+                elements within another :class:`Mesh`.
+        """
+        for idx in range(len(self.elements)):
+            if elem == self.elements[idx] and face == self.element_faces[idx]:
+                return (self.neighbors[idx], self.neighbor_faces[idx])
+        raise RuntimeError("This face does not have a neighbor")
 
 # }}}
 
@@ -526,6 +604,7 @@ class Mesh(Record):
             node_vertex_consistency_tolerance=None,
             nodal_adjacency=False,
             facial_adjacency_groups=False,
+            interpartition_adj=False,
             boundary_tags=None,
             vertex_id_dtype=np.int32,
             element_id_dtype=np.int32):
@@ -607,6 +686,7 @@ class Mesh(Record):
                 self, vertices=vertices, groups=new_groups,
                 _nodal_adjacency=nodal_adjacency,
                 _facial_adjacency_groups=facial_adjacency_groups,
+                interpartition_adj=interpartition_adj,
                 boundary_tags=boundary_tags,
                 btag_to_index=btag_to_index,
                 vertex_id_dtype=np.dtype(vertex_id_dtype),
@@ -716,6 +796,7 @@ class Mesh(Record):
                         == other._nodal_adjacency)
                 and (self._facial_adjacency_groups
                         == other._facial_adjacency_groups)
+                and self.interpartition_adj == other.interpartition_adj
                 and self.boundary_tags == other.boundary_tags)
 
     def __ne__(self, other):
@@ -885,6 +966,7 @@ def _compute_facial_adjacency_from_vertices(mesh):
 
         for ineighbor_group in range(len(mesh.groups)):
             nb_count = group_count.get((igroup, ineighbor_group))
+            # FIXME nb_count is None sometimes when it maybe shouldn't be.
             if nb_count is not None:
                 elements = np.empty(nb_count, dtype=mesh.element_id_dtype)
                 element_faces = np.empty(nb_count, dtype=mesh.face_id_dtype)
@@ -912,6 +994,7 @@ def _compute_facial_adjacency_from_vertices(mesh):
                 idx = fill_count.get((igrp, inb_grp), 0)
                 fill_count[igrp, inb_grp] = idx + 1
 
+                # FIXME KeyError with inb_grp sometimes.
                 fagrp = facial_adjacency_groups[igroup][inb_grp]
                 fagrp.elements[idx] = iel
                 fagrp.element_faces[idx] = iface
diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index b1fc9e25..dba72d9d 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -138,21 +138,19 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 type(mesh_group)(mesh_group.order, new_indices[group_nr],
                     new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
 
-    num_parts = np.max(part_per_element)
-    boundary_tags = list(range(num_parts))
+    from meshmode.mesh import BTAG_ALL, BTAG_PARTITION
+    boundary_tags = [BTAG_PARTITION(n) for n in range(np.max(part_per_element))]
 
     from meshmode.mesh import Mesh
     part_mesh = Mesh(new_vertices, new_mesh_groups,
         facial_adjacency_groups=None, boundary_tags=boundary_tags)
 
-    from meshmode.mesh import BTAG_ALL
-
-    #TODO This should probably be in the Mesh class.
+    # FIXME I get errors when I try to copy part_mesh.
     from meshmode.mesh import InterPartitionAdj
     part_mesh.interpartition_adj = InterPartitionAdj()
 
     for igrp in range(num_groups):
-        part_group = part_mesh.groups[igrp]
+        elem_base = part_mesh.groups[igrp].element_nr_base
         boundary_adj = part_mesh.facial_adjacency_groups[igrp][None]
         boundary_elems = boundary_adj.elements
         boundary_faces = boundary_adj.element_faces
@@ -160,7 +158,6 @@ def partition_mesh(mesh, part_per_element, part_nr):
             elem = boundary_elems[elem_idx]
             face = boundary_faces[elem_idx]
             tags = -boundary_adj.neighbors[elem_idx]
-            # Is is reasonable to expect this assertation?
             assert tags >= 0, "Expected boundary tag in adjacency group."
             parent_elem = queried_elems[elem]
             parent_group_num = 0
@@ -169,8 +166,8 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 parent_group_num += 1
             assert parent_group_num < num_groups, "Unable to find neighbor."
             parent_grp_elem_base = mesh.groups[parent_group_num].element_nr_base
-            parent_boundary_adj = mesh.facial_adjacency_groups[parent_group_num]
-            for _, parent_facial_group in parent_boundary_adj.items():
+            parent_adj = mesh.facial_adjacency_groups[parent_group_num]
+            for _, parent_facial_group in parent_adj.items():
                 for idx in np.where(parent_facial_group.elements == parent_elem)[0]:
                     if parent_facial_group.neighbors[idx] >= 0 and \
                             parent_facial_group.element_faces[idx] == face:
@@ -180,7 +177,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
 
                         n_part_nr = part_per_element[rank_neighbor]
                         tags = tags & ~part_mesh.boundary_tag_bit(BTAG_ALL)
-                        tags = tags | part_mesh.boundary_tag_bit(n_part_nr)
+                        tags = tags | part_mesh.boundary_tag_bit(BTAG_PARTITION(n_part_nr))
                         boundary_adj.neighbors[elem_idx] = -tags
 
                         # Find the neighbor element from the other partition
@@ -190,7 +187,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
                         # TODO Test if this works with multiple groups
                         # Do I need to add the element number base?
                         part_mesh.interpartition_adj.add_connection(
-                            elem + part_group.element_nr_base,
+                            elem + elem_base,
                             face,
                             n_elem,
                             rank_neighbor_face)
@@ -425,10 +422,13 @@ def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False):
         order = None
         unit_nodes = None
         nodal_adjacency = None
+        facial_adjacency_groups = None
 
         for mesh in meshes:
             if mesh._nodal_adjacency is not None:
                 nodal_adjacency = False
+            if mesh._facial_adjacency_groups is not None:
+                facial_adjacency_groups = False
 
             for group in mesh.groups:
                 if grp_cls is None:
@@ -455,10 +455,13 @@ def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False):
     else:
         new_groups = []
         nodal_adjacency = None
+        facial_adjacency_groups = None
 
         for mesh, vert_base in zip(meshes, vert_bases):
             if mesh._nodal_adjacency is not None:
                 nodal_adjacency = False
+            if mesh._facial_adjacency_groups is not None:
+                facial_adjacency_groups = False
 
             for group in mesh.groups:
                 new_vertex_indices = group.vertex_indices + vert_base
@@ -469,7 +472,8 @@ def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False):
 
     from meshmode.mesh import Mesh
     return Mesh(vertices, new_groups, skip_tests=skip_tests,
-            nodal_adjacency=nodal_adjacency)
+            nodal_adjacency=nodal_adjacency,
+            facial_adjacency_groups=facial_adjacency_groups)
 
 # }}}
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index e94a3512..c0a697b9 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -92,44 +92,69 @@ def test_partition_boxes_mesh():
 
     from meshmode.mesh.processing import partition_mesh
     new_meshes = [
-        partition_mesh(mesh, part_per_element, i)[0] for i in range(num_parts)]
+        partition_mesh(mesh, part_per_element, i) for i in range(num_parts)]
 
     assert mesh.nelements == np.sum(
-        [new_meshes[i].nelements for i in range(num_parts)]), \
+        [new_meshes[i][0].nelements for i in range(num_parts)]), \
         "part_mesh has the wrong number of elements"
 
-    assert count_btag_all(mesh) == np.sum(
-        [count_btag_all(new_meshes[i]) for i in range(num_parts)]), \
+    assert count_tags(mesh, BTAG_ALL) == np.sum(
+        [count_tags(new_meshes[i][0], BTAG_ALL) for i in range(num_parts)]), \
         "part_mesh has the wrong number of BTAG_ALL boundaries"
 
+    from meshmode.mesh import BTAG_PARTITION
+    num_tags = np.zeros((num_parts,))
+
     for part_nr in range(num_parts):
-        for f_groups in new_meshes[part_nr].facial_adjacency_groups:
+        (part, part_to_global) = new_meshes[part_nr]
+        for f_groups in part.facial_adjacency_groups:
             f_grp = f_groups[None]
             for idx in range(len(f_grp.elements)):
-                # Are all f_grp.neighbors guaranteed to be negative
-                # since I'm taking the boundary facial group?
                 tag = -f_grp.neighbors[idx]
+                assert tag >= 0
                 elem = f_grp.elements[idx]
                 face = f_grp.element_faces[idx]
                 for n_part_nr in range(num_parts):
-                    # Is tag >= 0 always true?
-                    if tag & new_meshes[part_nr].boundary_tag_bit(n_part_nr) != 0:
-                        # Is this the best way to probe the tag?
-                        # Can one tag have multiple partition neighbors?
-                        (n_elem, n_face) = new_meshes[part_nr].\
-                                interpartition_adj.get_neighbor(elem, face)
-                        assert (elem, face) == new_meshes[n_part_nr].\
-                                interpartition_adj.get_neighbor(n_elem, n_face),\
-                                "InterpartitionAdj is not consistent"
-
-
-def count_btag_all(mesh):
+                    (n_part, n_part_to_global) = new_meshes[n_part_nr]
+                    if tag & part.boundary_tag_bit(BTAG_PARTITION(n_part_nr)) != 0:
+                        num_tags[n_part_nr] += 1
+                        (n_elem, n_face) = part.interpartition_adj.\
+                                            get_neighbor(elem, face)
+                        assert (elem, face) == n_part.interpartition_adj.\
+                                            get_neighbor(n_elem, n_face),\
+                                            "InterpartitionAdj is not consistent"
+                        p_elem = part_to_global[elem]
+                        n_part_to_global = new_meshes[n_part_nr][1]
+                        p_n_elem = n_part_to_global[n_elem]
+                        p_grp_nr = 0
+                        while p_elem >= mesh.groups[p_grp_nr].nelements:
+                            p_elem -= mesh.groups[p_grp_nr].nelements
+                            p_grp_nr += 1
+                        p_elem_base = mesh.groups[p_grp_nr].element_nr_base
+                        f_groups = mesh.facial_adjacency_groups[p_grp_nr]
+                        for _, p_bnd_adj in f_groups.items():
+                            for idx in range(len(p_bnd_adj.elements)):
+                                if (p_elem == p_bnd_adj.elements[idx] and
+                                         face == p_bnd_adj.element_faces[idx]):
+                                    assert p_n_elem == p_bnd_adj.neighbors[idx],\
+                                            "Tag does not give correct neighbor"
+                                    assert n_face == p_bnd_adj.neighbor_faces[idx],\
+                                            "Tag does not give correct neighbor"
+
+    for tag_nr in range(num_parts):
+        tag_sum = 0
+        for mesh, _ in new_meshes:
+            tag_sum += count_tags(mesh, BTAG_PARTITION(tag_nr))
+        assert num_tags[tag_nr] == tag_sum,\
+                "part_mesh has the wrong number of BTAG_PARTITION boundaries"
+
+def count_tags(mesh, tag):
     num_bnds = 0
     for adj_dict in mesh.facial_adjacency_groups:
         for _, bdry_group in adj_dict.items():
             for neighbors in bdry_group.neighbors:
                 if neighbors < 0:
-                    if -neighbors & mesh.boundary_tag_bit(BTAG_ALL) != 0:
+                    if -neighbors & mesh.boundary_tag_bit(tag) != 0:
                         num_bnds += 1
     return num_bnds
 
-- 
GitLab