From adaca2cf838c9e50613a5665ccbfb864f66963fd Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 20 Oct 2015 22:37:42 -0500
Subject: [PATCH] make_face_restriction: generalize boundary restrictions to
 understand boundary tags and interior faces

---
 meshmode/discretization/connection.py | 117 ++++++++++++++++++--------
 test/test_meshmode.py                 |  27 +++---
 2 files changed, 99 insertions(+), 45 deletions(-)

diff --git a/meshmode/discretization/connection.py b/meshmode/discretization/connection.py
index 2b5eadeb..0eb5568a 100644
--- a/meshmode/discretization/connection.py
+++ b/meshmode/discretization/connection.py
@@ -41,7 +41,7 @@ __doc__ = """
 
 .. autofunction:: make_same_mesh_connection
 
-.. autofunction:: make_boundary_restriction
+.. autofunction:: make_face_restriction
 
 Implementation details
 ^^^^^^^^^^^^^^^^^^^^^^
@@ -318,66 +318,115 @@ def _build_boundary_connection(queue, vol_discr, bdry_discr, connection_data):
             vol_discr, bdry_discr, connection_groups)
 
 
-def make_boundary_restriction(queue, discr, group_factory):
-    """
-    :return: a tuple ``(bdry_mesh, bdry_discr, connection)``
-    """
+def _get_face_vertices(mesh, boundary_tag):
+    # a set of volume vertex numbers
+    bdry_vertex_vol_nrs = set()
 
-    logger.info("building boundary connection: start")
+    # {{{ pull together boundary vertices
 
-    # {{{ build face_map
+    if boundary_tag is not None:
+        # {{{ boundary faces
 
-    # maps (igrp, el_grp, face_id) to a frozenset of vertex IDs
-    face_map = {}
+        btag_bit = mesh.boundary_tag_bit(boundary_tag)
 
-    for igrp, mgrp in enumerate(discr.mesh.groups):
-        grp_face_vertex_indices = mgrp.face_vertex_indices()
+        for fagrp_map in mesh.facial_adjacency_groups:
+            bdry_grp = fagrp_map.get(None)
+            if bdry_grp is None:
+                continue
 
-        for iel_grp in range(mgrp.nelements):
-            for fid, loc_face_vertices in enumerate(grp_face_vertex_indices):
-                face_vertices = frozenset(
-                        mgrp.vertex_indices[iel_grp, fvi]
-                        for fvi in loc_face_vertices
-                        )
-                face_map.setdefault(face_vertices, []).append(
-                        (igrp, iel_grp, fid))
+            assert (bdry_grp.neighbors < 0).all()
 
-    del face_vertices
+            grp = mesh.groups[bdry_grp.igroup]
 
-    # }}}
+            nb_el_bits = -bdry_grp.neighbors
+            face_relevant_flags = (nb_el_bits & btag_bit) != 0
+
+            for iface, fvi in enumerate(grp.face_vertex_indices()):
+                bdry_vertex_vol_nrs.update(
+                        grp.vertex_indices
+                        [bdry_grp.elements[face_relevant_flags]]
+                        [:, np.array(fvi, dtype=np.intp)]
+                        .flat)
+
+        return np.array(sorted(bdry_vertex_vol_nrs), dtype=np.intp)
+
+        # }}}
+    else:
+        # For interior faces, this is likely every vertex in the book.
+        # Don't ever bother trying to cut the list down.
+
+        return np.arange(mesh.nvertices, dtype=np.intp)
+
+
+def make_face_restriction(queue, discr, group_factory, boundary_tag):
+    """Create a mesh, a discretization and a connection to restrict
+    a function on *discr* to its values on the edges of element faces
+    denoted by *boundary_tag*.
+
+    :arg boundary_tag: The boundary tag for which to create a face
+        restriction. May be *None* to indicate interior faces.
+
+    :return: a tuple ``(bdry_mesh, bdry_discr, connection)``
+    """
+
+    logger.info("building face restriction: start")
 
-    boundary_faces = [
-            face_ids[0]
-            for face_vertices, face_ids in six.iteritems(face_map)
-            if len(face_ids) == 1]
+    # {{{ gather boundary vertices
 
-    from pytools import flatten
-    bdry_vertex_vol_nrs = sorted(set(flatten(six.iterkeys(face_map))))
+    bdry_vertex_vol_nrs = _get_face_vertices(discr.mesh, boundary_tag)
 
     vol_to_bdry_vertices = np.empty(
             discr.mesh.vertices.shape[-1],
             discr.mesh.vertices.dtype)
     vol_to_bdry_vertices.fill(-1)
     vol_to_bdry_vertices[bdry_vertex_vol_nrs] = np.arange(
-            len(bdry_vertex_vol_nrs))
+            len(bdry_vertex_vol_nrs), dtype=np.intp)
 
     bdry_vertices = discr.mesh.vertices[:, bdry_vertex_vol_nrs]
 
+    # }}}
+
     from meshmode.mesh import Mesh, SimplexElementGroup
     bdry_mesh_groups = []
     connection_data = {}
 
-    for igrp, grp in enumerate(discr.groups):
+    btag_bit = discr.mesh.boundary_tag_bit(boundary_tag)
+
+    for igrp, (grp, fagrp_map) in enumerate(
+            zip(discr.groups, discr.mesh.facial_adjacency_groups)):
+
         mgrp = grp.mesh_el_group
-        group_boundary_faces = [
-                (ibface_el, ibface_face)
-                for ibface_group, ibface_el, ibface_face in boundary_faces
-                if ibface_group == igrp]
 
         if not isinstance(mgrp, SimplexElementGroup):
             raise NotImplementedError("can only take boundary of "
                     "SimplexElementGroup-based meshes")
 
+        # {{{ pull together per-group face lists
+
+        group_boundary_faces = []
+
+        if boundary_tag is not None:
+            bdry_grp = fagrp_map.get(None)
+            if bdry_grp is not None:
+                nb_el_bits = -bdry_grp.neighbors
+                face_relevant_flags = (nb_el_bits & btag_bit) != 0
+
+                group_boundary_faces.extend(
+                            zip(
+                                bdry_grp.elements[face_relevant_flags],
+                                bdry_grp.element_faces[face_relevant_flags]))
+
+        else:
+            for fagrp in six.itervalues(fagrp_map):
+                if fagrp.ineighbor_group is None:
+                    # boundary faces -> not looking for those
+                    continue
+
+                group_boundary_faces.extend(
+                        zip(bdry_grp.elements, bdry_grp.element_faces))
+
+        # }}}
+
         # {{{ Preallocate arrays for mesh group
 
         ngroup_bdry_elements = len(group_boundary_faces)
@@ -479,7 +528,7 @@ def make_boundary_restriction(queue, discr, group_factory):
     connection = _build_boundary_connection(
             queue, discr, bdry_discr, connection_data)
 
-    logger.info("building boundary connection: done")
+    logger.info("building face restriction: done")
 
     return bdry_mesh, bdry_discr, connection
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 2bd67300..55602e0f 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -67,10 +67,11 @@ def test_boundary_interpolation(ctx_getter):
     queue = cl.CommandQueue(cl_ctx)
 
     from meshmode.mesh.io import generate_gmsh, FileSource
+    from meshmode.mesh import BTAG_ALL
     from meshmode.discretization import Discretization
     from meshmode.discretization.poly_element import \
             InterpolatoryQuadratureSimplexGroupFactory
-    from meshmode.discretization.connection import make_boundary_restriction
+    from meshmode.discretization.connection import make_face_restriction
 
     from pytools.convergence import EOCRecorder
     eoc_rec = EOCRecorder()
@@ -95,8 +96,9 @@ def test_boundary_interpolation(ctx_getter):
         x = vol_discr.nodes()[0].with_queue(queue)
         f = 0.1*cl.clmath.sin(30*x)
 
-        bdry_mesh, bdry_discr, bdry_connection = make_boundary_restriction(
-                queue, vol_discr, InterpolatoryQuadratureSimplexGroupFactory(order))
+        bdry_mesh, bdry_discr, bdry_connection = make_face_restriction(
+                queue, vol_discr, InterpolatoryQuadratureSimplexGroupFactory(order),
+                BTAG_ALL)
 
         bdry_x = bdry_discr.nodes()[0].with_queue(queue)
         bdry_f = 0.1*cl.clmath.sin(30*bdry_x)
@@ -189,14 +191,14 @@ def test_sanity_single_element(ctx_getter, dim, order, visualize=False):
     center.fill(-0.5)
 
     import modepy as mp
-    from meshmode.mesh import SimplexElementGroup, Mesh
+    from meshmode.mesh import SimplexElementGroup, Mesh, BTAG_ALL
     mg = SimplexElementGroup(
             order=order,
             vertex_indices=np.arange(dim+1, dtype=np.int32).reshape(1, -1),
             nodes=mp.warp_and_blend_nodes(dim, order).reshape(dim, 1, -1),
             dim=dim)
 
-    mesh = Mesh(vertices, [mg])
+    mesh = Mesh(vertices, [mg], nodal_adjacency=None, facial_adjacency_groups=None)
 
     from meshmode.discretization import Discretization
     from meshmode.discretization.poly_element import \
@@ -224,9 +226,10 @@ def test_sanity_single_element(ctx_getter, dim, order, visualize=False):
 
     # {{{ boundary discretization
 
-    from meshmode.discretization.connection import make_boundary_restriction
-    bdry_mesh, bdry_discr, bdry_connection = make_boundary_restriction(
-            queue, vol_discr, PolynomialWarpAndBlendGroupFactory(order + 3))
+    from meshmode.discretization.connection import make_face_restriction
+    bdry_mesh, bdry_discr, bdry_connection = make_face_restriction(
+            queue, vol_discr, PolynomialWarpAndBlendGroupFactory(order + 3),
+            BTAG_ALL)
 
     # }}}
 
@@ -281,6 +284,7 @@ def test_sanity_balls(ctx_getter, src_file, dim, mesh_order,
     from pytential import bind, sym
 
     for h in [0.2, 0.14, 0.1]:
+        from meshmode.mesh import BTAG_ALL
         from meshmode.mesh.io import generate_gmsh, FileSource
         mesh = generate_gmsh(
                 FileSource(src_file), dim, order=mesh_order,
@@ -297,10 +301,11 @@ def test_sanity_balls(ctx_getter, src_file, dim, mesh_order,
         vol_discr = Discretization(ctx, mesh,
                 InterpolatoryQuadratureSimplexGroupFactory(quad_order))
 
-        from meshmode.discretization.connection import make_boundary_restriction
-        bdry_mesh, bdry_discr, bdry_connection = make_boundary_restriction(
+        from meshmode.discretization.connection import make_face_restriction
+        bdry_mesh, bdry_discr, bdry_connection = make_face_restriction(
                 queue, vol_discr,
-                InterpolatoryQuadratureSimplexGroupFactory(quad_order))
+                InterpolatoryQuadratureSimplexGroupFactory(quad_order),
+                BTAG_ALL)
 
         # }}}
 
-- 
GitLab