From daaef0c6b4637144c9405c09f493509f213e6965 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 4 Jan 2016 12:18:25 -0600
Subject: [PATCH] Implement face -> all_faces connection

---
 .../discretization/connection/__init__.py     |   4 +-
 meshmode/discretization/connection/face.py    | 105 ++++++++++++++++
 test/test_meshmode.py                         | 113 ++++++++++++++++++
 3 files changed, 221 insertions(+), 1 deletion(-)

diff --git a/meshmode/discretization/connection/__init__.py b/meshmode/discretization/connection/__init__.py
index 4cad16c9..c4bcf4a5 100644
--- a/meshmode/discretization/connection/__init__.py
+++ b/meshmode/discretization/connection/__init__.py
@@ -33,7 +33,7 @@ from meshmode.discretization.connection.same_mesh import \
         make_same_mesh_connection
 from meshmode.discretization.connection.face import (
         FRESTR_INTERIOR_FACES, FRESTR_ALL_FACES,
-        make_face_restriction)
+        make_face_restriction, make_face_to_all_faces_embedding)
 from meshmode.discretization.connection.opposite_face import \
         make_opposite_face_connection
 
@@ -46,6 +46,7 @@ __all__ = [
         "make_same_mesh_connection",
         "FRESTR_INTERIOR_FACES", "FRESTR_ALL_FACES",
         "make_face_restriction",
+        "make_face_to_all_faces_embedding",
         "make_opposite_face_connection"
         ]
 
@@ -57,6 +58,7 @@ __doc__ = """
 .. autofunction:: FRESTR_INTERIOR_FACES
 .. autofunction:: FRESTR_ALL_FACES
 .. autofunction:: make_face_restriction
+.. autofunction:: make_face_to_all_faces_embedding
 
 .. autofunction:: make_opposite_face_connection
 
diff --git a/meshmode/discretization/connection/face.py b/meshmode/discretization/connection/face.py
index 14127b32..7c7e7dde 100644
--- a/meshmode/discretization/connection/face.py
+++ b/meshmode/discretization/connection/face.py
@@ -379,4 +379,109 @@ def make_face_restriction(discr, group_factory, boundary_tag,
 
 # }}}
 
+
+# {{{ face -> all_faces connection
+
+def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
+    """Return a
+    :class:`meshmode.discretization.connection.DiscretizationConnection`
+    connecting a discretization containing some faces of a discretization
+    to one containing all faces.
+
+    :arg faces_connection: must be the (connection) result of calling
+        :func:`meshmode.discretization.connection.make_face_restriction`
+        with
+        :class:`meshmode.discretization.connection.FRESTR_INTERIOR_FACES`
+        or a boundary tag.
+    :arg all_faces_discr: must be the (discretization) result of calling
+        :func:`meshmode.discretization.connection.make_face_restriction`
+        with
+        :class:`meshmode.discretization.connection.FRESTR_ALL_FACES`
+        for the same volume discretization as the one from which
+        *faces_discr* was obtained.
+    """
+
+    vol_discr = faces_connection.from_discr
+    faces_discr = faces_connection.to_discr
+
+    per_face_groups = (
+            len(vol_discr.groups) != len(faces_discr.groups))
+
+    if len(faces_discr.groups) != len(all_faces_discr.groups):
+        raise ValueError("faces_discr and all_faces_discr must have the "
+                "same number of groups")
+    if len(faces_connection.groups) != len(all_faces_discr.groups):
+        raise ValueError("faces_connection and all_faces_discr must have the "
+                "same number of groups")
+
+    from meshmode.discretization.connection import (
+            DiscretizationConnection,
+            DiscretizationConnectionElementGroup,
+            InterpolationBatch)
+
+    i_faces_grp = 0
+
+    with cl.CommandQueue(vol_discr.cl_context) as queue:
+        groups = []
+        for ivol_grp, vol_grp in enumerate(vol_discr.groups):
+            batches = []
+
+            nfaces = vol_grp.mesh_el_group.nfaces
+            for iface in range(nfaces):
+                faces_grp = faces_discr.groups[i_faces_grp]
+                all_faces_grp = all_faces_discr.groups[i_faces_grp]
+
+                if per_face_groups:
+                    assert len(faces_connection.groups[i_faces_grp].batches) == 1
+                else:
+                    assert len(faces_connection.groups[i_faces_grp].batches) == nfaces
+
+                assert np.array_equal(
+                        faces_grp.unit_nodes, all_faces_grp.unit_nodes)
+
+                # {{{ find src_batch
+
+                src_batches = faces_connection.groups[i_faces_grp].batches
+                if per_face_groups:
+                    src_batch, = src_batches
+                else:
+                    src_batch = src_batches[iface]
+                del src_batches
+
+                # }}}
+
+                if per_face_groups:
+                    to_element_indices = src_batch.from_element_indices
+                else:
+                    assert all_faces_grp.nelements == nfaces * vol_grp.nelements
+
+                    to_element_indices = (
+                            vol_grp.nelements*iface
+                            + src_batch.from_element_indices.with_queue(queue)
+                            ).with_queue(None)
+
+                batches.append(
+                        InterpolationBatch(
+                            from_group_index=i_faces_grp,
+                            from_element_indices=src_batch.to_element_indices,
+                            to_element_indices=to_element_indices,
+                            result_unit_nodes=all_faces_grp.unit_nodes,
+                            to_element_face=None))
+
+                is_last_face = iface + 1 == nfaces
+                if per_face_groups or is_last_face:
+                    groups.append(
+                            DiscretizationConnectionElementGroup(batches=batches))
+                    batches = []
+
+                    i_faces_grp += 1
+
+    return DiscretizationConnection(
+            faces_connection.to_discr,
+            all_faces_discr,
+            groups,
+            is_surjective=False)
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index c169e284..97d38daf 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -168,6 +168,119 @@ def test_boundary_interpolation(ctx_getter, group_factory, boundary_tag,
             eoc_rec.order_estimate() >= order-0.5
             or eoc_rec.max_error() < 1e-14)
 
+# }}}
+
+
+# {{{ boundary-to-all-faces connecttion
+
+@pytest.mark.parametrize(("mesh_name", "dim", "mesh_pars"), [
+    ("blob", 2, [1e-1, 8e-2, 5e-2]),
+    ("warp", 2, [10, 20, 30]),
+    ("warp", 3, [10, 20, 30]),
+    ])
+@pytest.mark.parametrize("per_face_groups", [False, True])
+def test_all_faces_interpolation(ctx_getter, mesh_name, dim, mesh_pars,
+        per_face_groups):
+    cl_ctx = ctx_getter()
+    queue = cl.CommandQueue(cl_ctx)
+
+    from meshmode.discretization import Discretization
+    from meshmode.discretization.connection import (
+            make_face_restriction, make_face_to_all_faces_embedding,
+            check_connection)
+
+    from pytools.convergence import EOCRecorder
+    eoc_rec = EOCRecorder()
+
+    order = 4
+
+    def f(x):
+        return 0.1*cl.clmath.sin(30*x)
+
+    for mesh_par in mesh_pars:
+        # {{{ get mesh
+
+        if mesh_name == "blob":
+            assert dim == 2
+
+            h = mesh_par
+
+            from meshmode.mesh.io import generate_gmsh, FileSource
+            print("BEGIN GEN")
+            mesh = generate_gmsh(
+                    FileSource("blob-2d.step"), 2, order=order,
+                    force_ambient_dim=2,
+                    other_options=[
+                        "-string", "Mesh.CharacteristicLengthMax = %s;" % h]
+                    )
+            print("END GEN")
+        elif mesh_name == "warp":
+            from meshmode.mesh.generation import generate_warped_rect_mesh
+            mesh = generate_warped_rect_mesh(dim, order=4, n=mesh_par)
+
+            h = 1/mesh_par
+        else:
+            raise ValueError("mesh_name not recognized")
+
+        # }}}
+
+        vol_discr = Discretization(cl_ctx, mesh,
+                PolynomialWarpAndBlendGroupFactory(order))
+        print("h=%s -> %d elements" % (
+                h, sum(mgrp.nelements for mgrp in mesh.groups)))
+
+        all_face_bdry_connection = make_face_restriction(
+                vol_discr, PolynomialWarpAndBlendGroupFactory(order),
+                FRESTR_ALL_FACES, per_face_groups=per_face_groups)
+        all_face_bdry_discr = all_face_bdry_connection.to_discr
+
+        for ito_grp, ceg in enumerate(all_face_bdry_connection.groups):
+            for ibatch, batch in enumerate(ceg.batches):
+                assert np.array_equal(
+                        batch.from_element_indices.get(queue),
+                        np.arange(vol_discr.mesh.nelements))
+
+                if per_face_groups:
+                    assert ito_grp == batch.to_element_face
+                else:
+                    assert ibatch == batch.to_element_face
+
+        all_face_x = all_face_bdry_discr.nodes()[0].with_queue(queue)
+        all_face_f = f(all_face_x)
+
+        all_face_f_2 = all_face_bdry_discr.zeros(queue)
+
+        for boundary_tag in [
+                BTAG_ALL,
+                FRESTR_INTERIOR_FACES,
+                ]:
+            bdry_connection = make_face_restriction(
+                    vol_discr, PolynomialWarpAndBlendGroupFactory(order),
+                    boundary_tag, per_face_groups=per_face_groups)
+            bdry_discr = bdry_connection.to_discr
+
+            bdry_x = bdry_discr.nodes()[0].with_queue(queue)
+            bdry_f = f(bdry_x)
+
+            all_face_embedding = make_face_to_all_faces_embedding(
+                    bdry_connection, all_face_bdry_discr)
+
+            check_connection(all_face_embedding)
+
+            all_face_f_2 += all_face_embedding(queue, bdry_f)
+
+        err = la.norm((all_face_f-all_face_f_2).get(), np.inf)
+        eoc_rec.add_data_point(h, err)
+
+    print(eoc_rec)
+    assert (
+            eoc_rec.order_estimate() >= order-0.5
+            or eoc_rec.max_error() < 1e-14)
+
+# }}}
+
+
+# {{{ convergence of opposite-face interpolation
 
 @pytest.mark.parametrize("group_factory", [
     InterpolatoryQuadratureSimplexGroupFactory,
-- 
GitLab