From 558d08ac40c3702aa1793160de23e4a2ca0e107c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 4 Dec 2017 21:49:03 -0600
Subject: [PATCH] Add from_discr parameter to make_face_to_all_faces_embedding

---
 meshmode/discretization/connection/face.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/meshmode/discretization/connection/face.py b/meshmode/discretization/connection/face.py
index 6ff6d0ca..d2745225 100644
--- a/meshmode/discretization/connection/face.py
+++ b/meshmode/discretization/connection/face.py
@@ -388,7 +388,8 @@ def make_face_restriction(discr, group_factory, boundary_tag,
 
 # {{{ face -> all_faces connection
 
-def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
+def make_face_to_all_faces_embedding(faces_connection, all_faces_discr,
+        from_discr=None):
     """Return a
     :class:`meshmode.discretization.connection.DiscretizationConnection`
     connecting a discretization containing some faces of a discretization
@@ -405,11 +406,19 @@ def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
         :class:`meshmode.discretization.connection.FACE_RESTR_ALL`
         for the same volume discretization as the one from which
         *faces_discr* was obtained.
+    :arg from_discr: Allows substituting in a different origin
+        discretization for the returned connection. This discretization
+        must use the same mesh as ``faces_connection.to_discr``.
     """
 
     vol_discr = faces_connection.from_discr
     faces_discr = faces_connection.to_discr
 
+    if from_discr is None:
+        from_discr = faces_discr
+
+    assert from_discr.mesh is faces_discr.mesh
+
     per_face_groups = (
             len(vol_discr.groups) != len(faces_discr.groups))
 
@@ -434,7 +443,6 @@ def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
 
             nfaces = vol_grp.mesh_el_group.nfaces
             for iface in range(nfaces):
-                faces_grp = faces_discr.groups[i_faces_grp]
                 all_faces_grp = all_faces_discr.groups[i_faces_grp]
 
                 if per_face_groups:
@@ -444,7 +452,8 @@ def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
                             == nfaces)
 
                 assert np.array_equal(
-                        faces_grp.unit_nodes, all_faces_grp.unit_nodes)
+                        from_discr.groups[i_faces_grp].unit_nodes,
+                        all_faces_grp.unit_nodes)
 
                 # {{{ find src_batch
 
@@ -484,7 +493,7 @@ def make_face_to_all_faces_embedding(faces_connection, all_faces_discr):
                     i_faces_grp += 1
 
     return DirectDiscretizationConnection(
-            faces_connection.to_discr,
+            from_discr,
             all_faces_discr,
             groups,
             is_surjective=False)
-- 
GitLab