From fcfb51104c09971d8fc329f6b665cbee9fca6a0e Mon Sep 17 00:00:00 2001
From: benSepanski <ben_sepanski@alumni.baylor.edu>
Date: Thu, 2 Jul 2020 10:41:22 -0500
Subject: [PATCH] Made FiredrakeConnection class from node mapping to avoid
 code duplication

---
 meshmode/interop/firedrake/connection.py | 447 +++++++++++++++--------
 test/test_firedrake_interop.py           |  30 +-
 2 files changed, 321 insertions(+), 156 deletions(-)

diff --git a/meshmode/interop/firedrake/connection.py b/meshmode/interop/firedrake/connection.py
index a73ec5ef..6a11c70e 100644
--- a/meshmode/interop/firedrake/connection.py
+++ b/meshmode/interop/firedrake/connection.py
@@ -21,6 +21,8 @@ THE SOFTWARE.
 """
 
 __doc__ = """
+.. autoclass:: FiredrakeConnection
+    :members:
 .. autoclass:: FromFiredrakeConnection
     :members:
 """
@@ -31,14 +33,16 @@ import six
 
 from modepy import resampling_matrix
 
-from meshmode.interop.firedrake.mesh import import_firedrake_mesh
+from meshmode.interop.firedrake.mesh import (
+    import_firedrake_mesh, export_mesh_to_firedrake)
 from meshmode.interop.firedrake.reference_cell import (
     get_affine_reference_simplex_mapping, get_finat_element_unit_nodes)
 
 from meshmode.mesh.processing import get_simplex_element_flip_matrix
 
 from meshmode.discretization.poly_element import \
-    InterpolatoryQuadratureSimplexGroupFactory
+    InterpolatoryQuadratureSimplexGroupFactory, \
+    InterpolatoryQuadratureSimplexElementGroup
 from meshmode.discretization import Discretization
 
 
@@ -74,52 +78,109 @@ def _reorder_nodes(orient, nodes, flip_matrix, unflip=False):
         flip_mat, nodes[orient < 0])
 
 
-# {{{ Create connection from firedrake into meshmode
-
+# {{{ Most basic connection between a fd function space and mm discretization
 
-class FromFiredrakeConnection:
+class FiredrakeConnection:
     """
-    A connection created from a :mod:`firedrake`
-    ``"CG"`` or ``"DG"`` function space which creates a corresponding
-    meshmode discretization and allows
-    transfer of functions to and from :mod:`firedrake`.
+    A connection between one group of
+    a meshmode discretization and a firedrake "CG" or "DG"
+    function space.
+
+    .. autoattribute:: discr
+
+        A meshmode discretization
 
-    .. attribute:: to_discr
+    .. autoattribute:: group_nr
 
-        The discretization corresponding to the firedrake function
-        space created with a
-        :class:`InterpolatoryQuadratureSimplexElementGroup`.
+        The group number identifying which element group of
+        :attr:`discr` is being connected to a firedrake function space
+
+    .. autoattribute:: mm2fd_node_mapping
+
+        A numpy array of shape *(self.discr.groups[group_nr].nnodes,)*
+        whose *i*th entry is the :mod:`firedrake` node index associated
+        to the *i*th node in *self.discr.groups[group_nr]*.
+        It is important to note that, due to :mod:`meshmode`
+        and :mod:`firedrake` using different unit nodes, a :mod:`firedrake`
+        node associated to a :mod:`meshmode` may have different coordinates.
+        However, after resampling to the other system's unit nodes,
+        two associated nodes should have identical coordinates.
     """
-    def __init__(self, cl_ctx, fdrake_fspace):
+    def __init__(self, discr, fdrake_fspace, mm2fd_node_mapping, group_nr=None):
         """
-        :arg cl_ctx: A :mod:`pyopencl` computing context
-        :arg fdrake_fspace: A :mod:`firedrake` ``"CG"`` or ``"DG"``
-            function space (of class :class:`WithGeometry`) built on
-            a mesh which is importable by :func:`import_firedrake_mesh`.
+        :param discr: A :mod:`meshmode` :class:`Discretization`
+        :param fdrake_fspace: A :mod:`firedrake`
+            :class:`firedrake.functionspaceimpl.WithGeometry`.
+            Must have ufl family ``'Lagrange'`` or
+            ``'Discontinuous Lagrange'``.
+        :param mm2fd_node_mapping: Used as attribute :attr:`mm2fd_node_mapping`.
+            A numpy integer array with the same dtype as
+            ``fdrake_fspace.cell_node_list.dtype``
+        :param group_nr: The index of the group in *discr* which is
+            being connected to *fdrake_fspace*. The group must be
+            a :class:`InterpolatoryQuadratureSimplexElementGroup`
+            of the same topological dimension as *fdrake_fspace*.
+            If *discr* has only one group, *group_nr=None* may be supplied.
+
+        :raises TypeError: If any input arguments are of the wrong type,
+            if the designated group is of the wrong type,
+            or if *fdrake_fspace* is of the wrong family.
+        :raises ValueError: If:
+            * *mm2fd_node_mapping* is of the wrong shape
+              or dtype, if *group_nr* is an invalid index
+            * If *group_nr* is *None* when *discr* has more than one group.
         """
-        # Ensure fdrake_fspace is a function space with appropriate reference
-        # element.
+        # {{{ Validate input
+        if not isinstance(discr, Discretization):
+            raise TypeError(":param:`discr` must be of type "
+                            ":class:`meshmode.discretization.Discretization`, "
+                            "not :class:`%s`." % type(discr))
         from firedrake.functionspaceimpl import WithGeometry
         if not isinstance(fdrake_fspace, WithGeometry):
-            raise TypeError(":arg:`fdrake_fspace` must be of firedrake type "
-                            ":class:`WithGeometry`, not `%s`."
-                            % type(fdrake_fspace))
-        ufl_elt = fdrake_fspace.ufl_element()
-
-        if ufl_elt.family() not in ('Lagrange', 'Discontinuous Lagrange'):
-            raise ValueError("the ``ufl_element().family()`` of "
-                             ":arg:`fdrake_fspace` must "
-                             "be ``'Lagrange'`` or "
-                             "``'Discontinuous Lagrange'``, not %s."
-                             % ufl_elt.family())
-
-        # Create to_discr
-        mm_mesh, orient = import_firedrake_mesh(fdrake_fspace.mesh())
-        factory = InterpolatoryQuadratureSimplexGroupFactory(ufl_elt.degree())
-        self.to_discr = Discretization(cl_ctx, mm_mesh, factory)
+            raise TypeError(":param:`fdrake_fspace` must be of type "
+                            ":class:`firedrake.functionspaceimpl.WithGeometry`, "
+                            "not :class:`%s`." % type(fdrake_fspace))
+        if not isinstance(mm2fd_node_mapping, np.ndarray):
+            raise TypeError(":param:`mm2fd_node_mapping` must be of type "
+                            ":class:`np.ndarray`, "
+                            "not :class:`%s`." % type(mm2fd_node_mapping))
+        if not isinstance(group_nr, int) and group_nr is not None:
+            raise TypeError(":param:`group_nr` must be of type *int* or be "
+                            "*None*, not of type %s." % type(group_nr))
+        # Convert group_nr to an integer if *None*
+        if group_nr is None:
+            if len(discr.groups) != 1:
+                raise ValueError(":param:`group_nr` is *None* but :param:`discr` "
+                                 "has %s != 1 groups." % len(discr.groups))
+            group_nr = 0
+        if group_nr < 0 or group_nr >= len(discr.groups):
+            raise ValueError(":param:`group_nr` has value %s, which an invalid "
+                             "index into list *discr.groups* of length %s."
+                             % (group_nr, len(discr.gropus)))
+        if not isinstance(discr.groups[group_nr],
+                          InterpolatoryQuadratureSimplexElementGroup):
+            raise TypeError("*discr.groups[group_nr]* must be of type "
+                            ":class:`InterpolatoryQuadratureSimplexElementGroup`"
+                            ", not :class:`%s`." % type(discr.groups[group_nr]))
+        allowed_families = ('Discontinuous Lagrange', 'Lagrange')
+        if fdrake_fspace.ufl_element().family() not in allowed_families:
+            raise TypeError(":param:`fdrake_fspace` must have ufl family "
+                           "be one of %s, not %s."
+                            % (allowed_families,
+                               fdrake_fspace.ufl_element().family()))
+        if mm2fd_node_mapping.shape != (discr.groups[group_nr].nnodes,):
+            raise ValueError(":param:`mm2fd_node_mapping` must be of shape ",
+                             "(%s,), not %s"
+                             % ((discr.groups[group_nr].nnodes,),
+                                mm2fd_node_mapping.shape))
+        if mm2fd_node_mapping.dtype != fdrake_fspace.cell_node_list.dtype:
+            raise ValueError(":param:`mm2fd_node_mapping` must have dtype "
+                             "%s, not %s" % (fdrake_fspace.cell_node_list.dtype,
+                                             mm2fd_node_mapping.dtype))
+        # }}}
 
         # Get meshmode unit nodes
-        element_grp = self.to_discr.groups[0]
+        element_grp = discr.groups[group_nr]
         mm_unit_nodes = element_grp.unit_nodes
         # get firedrake unit nodes and map onto meshmode reference element
         dim = fdrake_fspace.mesh().topological_dimension()
@@ -127,7 +188,7 @@ class FromFiredrakeConnection:
         fd_unit_nodes = get_finat_element_unit_nodes(fdrake_fspace.finat_element)
         fd_unit_nodes = fd_ref_cell_to_mm(fd_unit_nodes)
 
-        # compute resampling matrices
+        # compute and store resampling matrices
         self._resampling_mat_fd2mm = resampling_matrix(element_grp.basis(),
                                                        new_nodes=mm_unit_nodes,
                                                        old_nodes=fd_unit_nodes)
@@ -135,101 +196,98 @@ class FromFiredrakeConnection:
                                                        new_nodes=fd_unit_nodes,
                                                        old_nodes=mm_unit_nodes)
 
-        # Flipping negative elements corresponds to reordering the nodes.
-        # We handle reordering by storing the permutation explicitly as
-        # a numpy array
-
-        # Get the reordering fd->mm.
-        #
-        # One should note there is something a bit more subtle going on
-        # in the continuous case. All meshmode discretizations use
-        # are discontinuous, so nodes are associated with elements(cells)
-        # not vertices. In a continuous firedrake space, some nodes
-        # are shared between multiple cells. In particular, while the
-        # below "reordering" is indeed a permutation if the firedrake space
-        # is discontinuous, if the firedrake space is continuous then
-        # some firedrake nodes correspond to nodes on multiple meshmode
-        # elements, i.e. those nodes appear multiple times
-        # in the "reordering" array
-        flip_mat = get_simplex_element_flip_matrix(ufl_elt.degree(),
-                                                   fd_unit_nodes)
-        fd_cell_node_list = fdrake_fspace.cell_node_list
-        _reorder_nodes(orient, fd_cell_node_list, flip_mat, unflip=False)
-        self._reordering_arr_fd2mm = fd_cell_node_list.flatten()
-
-        # Now handle the possibility of duplicate nodes
-        unique_fd_nodes, counts = np.unique(self._reordering_arr_fd2mm,
+        # Now handle the possibility of multiple meshmode nodes being associated
+        # to the same firedrake node
+        unique_fd_nodes, counts = np.unique(mm2fd_node_mapping,
                                             return_counts=True)
         # self._duplicate_nodes
         # maps firedrake nodes associated to more than 1 meshmode node
         # to all associated meshmode nodes.
         self._duplicate_nodes = {}
-        if ufl_elt.family() == 'Discontinuous Lagrange':
-            assert np.all(counts == 1), \
-                "This error should never happen, some nodes in a firedrake " \
-                "discontinuous space were duplicated. Contact the developer "
-        else:
-            dup_fd_nodes = set(unique_fd_nodes[counts > 1])
-            for mm_inode, fd_inode in enumerate(self._reordering_arr_fd2mm):
-                if fd_inode in dup_fd_nodes:
-                    self._duplicate_nodes.setdefault(fd_inode, [])
-                    self._duplicate_nodes[fd_inode].append(mm_inode)
-
-        # Store things that we need for *from_fspace*
-        self._ufl_element = ufl_elt
+        dup_fd_nodes = set(unique_fd_nodes[counts > 1])
+        for mm_inode, fd_inode in enumerate(mm2fd_node_mapping):
+            if fd_inode in dup_fd_nodes:
+                self._duplicate_nodes.setdefault(fd_inode, [])
+                self._duplicate_nodes[fd_inode].append(mm_inode)
+
+        # Store input
+        self.discr = discr
+        self.group_nr = group_nr
+        self.mm2fd_node_mapping = mm2fd_node_mapping
         self._mesh_geometry = fdrake_fspace.mesh()
-        self._fspace_cache = {}  # map vector dim -> firedrake fspace
+        self._ufl_element = fdrake_fspace.ufl_element()
+        # Cache firedrake function spaces of each vector dimension to
+        # avoid overhead. Firedrake takes care of avoiding memory
+        # duplication.
+        self._fspace_cache = {}
 
-    def from_fspace(self, dim=None):
+    def firedrake_fspace(self, vdim=None):
         """
-        Return a firedrake function space of the appropriate vector dimension
-
-        :arg dim: Either *None*, in which case a function space which maps
-                    to scalar values is returned, or a positive integer *n*,
-                    in which case a function space which maps into *\\R^n*
-                    is returned
+        Return a firedrake function space that
+        *self.discr.groups[self.group_nr]* is connected to
+        of the appropriate vector dimension
+
+        :arg vdim: Either *None*, in which case a function space which maps
+                   to scalar values is returned, a positive integer *n*,
+                   in which case a function space which maps into *\\R^n*
+                   is returned, or a tuple of integers defining
+                   the shape of values in a tensor function space,
+                   in which case a tensor function space is returned
         :return: A :mod:`firedrake` :class:`WithGeometry` which corresponds to
-                 :attr:`to_discr` of the appropriate vector dimension
+                 *self.discr.groups[self.group_nr]* of the appropriate vector
+                 dimension
+
+        :raises TypeError: If *vdim* is of the wrong type
         """
         # Cache the function spaces created to avoid high overhead.
         # Note that firedrake is smart about re-using shared information,
         # so this is not duplicating mesh/reference element information
-        if dim not in self._fspace_cache:
-            assert (isinstance(dim, int) and dim > 0) or dim is None
-            if dim is None:
+        if vdim not in self._fspace_cache:
+            if vdim is None:
                 from firedrake import FunctionSpace
-                self._fspace_cache[dim] = \
+                self._fspace_cache[vdim] = \
                     FunctionSpace(self._mesh_geometry,
                                   self._ufl_element.family(),
                                   degree=self._ufl_element.degree())
-            else:
+            elif isinstance(vdim, int):
                 from firedrake import VectorFunctionSpace
-                self._fspace_cache[dim] = \
+                self._fspace_cache[vdim] = \
                     VectorFunctionSpace(self._mesh_geometry,
                                         self._ufl_element.family(),
                                         degree=self._ufl_element.degree(),
-                                        dim=dim)
-        return self._fspace_cache[dim]
+                                        dim=vdim)
+            elif isinstance(vdim, tuple):
+                from firedrake import TensorFunctionSpace
+                self._fspace_cache[vdim] = \
+                    TensorFunctionSpace(self._mesh_geometry,
+                                        self._ufl_element.family(),
+                                        degree=self._ufl_element.degree(),
+                                        shape=vdim)
+            else:
+                raise TypeError(":param:`vdim` must be *None*, an integer, "
+                                " or a tuple of integers, not of type %s."
+                                % type(vdim))
+        return self._fspace_cache[vdim]
 
     def from_firedrake(self, function, out=None):
         """
-        transport firedrake function onto :attr:`to_discr`
+        transport firedrake function onto :attr:`discr`
 
         :arg function: A :mod:`firedrake` function to transfer onto
-            :attr:`to_discr`. Its function space must have
+            :attr:`discr`. Its function space must have
             the same family, degree, and mesh as ``self.from_fspace()``.
         :arg out: If *None* then ignored, otherwise a numpy array of the
-            shape *function.dat.data.shape.T* (i.e.
-            *(dim, nnodes)* or *(nnodes,)* in which *function*'s
-            transported data is stored.
+            shape (i.e.
+            *(..., num meshmode nodes)* or *(num meshmode nodes,)* and of the
+            same dtype in which *function*'s transported data will be stored
 
         :return: a numpy array holding the transported function
         """
-        # make sure function is a firedrake function in an appropriate
-        # function space
+        # Check function is convertible
         from firedrake.function import Function
-        assert isinstance(function, Function), \
-            ":arg:`function` must be a :mod:`firedrake` Function"
+        if not isinstance(function, Function):
+            raise TypeError(":arg:`function` must be a :mod:`firedrake` "
+                            "Function")
         assert function.function_space().ufl_element().family() \
             == self._ufl_element.family() and \
             function.function_space().ufl_element().degree() \
@@ -240,22 +298,28 @@ class FromFiredrakeConnection:
             ":arg:`function` mesh must be the same mesh as used by " \
             "``self.from_fspace().mesh()``"
 
-        # Get function data as shape [nnodes][dims] or [nnodes]
+        # Get function data as shape *(nnodes, ...)* or *(nnodes,)*
         function_data = function.dat.data
 
-        if out is None:
-            shape = (self.to_discr.nnodes,)
-            if len(function_data.shape) > 1:
-                shape = (function_data.shape[1],) + shape
-            out = np.ndarray(shape, dtype=function_data.dtype)
-        # Reorder nodes
-        if len(out.shape) > 1:
-            out[:] = function_data.T[:, self._reordering_arr_fd2mm]
+        # Check that out is supplied correctly, or create out if it is
+        # not supplied
+        shape = (self.discr.groups[self.group_nr].nnodes,)
+        if len(function_data.shape) > 1:
+            shape = function_data.shape[1:] + shape
+        if out is not None:
+            if not isinstance(out, np.ndarray):
+                raise TypeError(":param:`out` must of type *np.ndarray* or "
+                                "be *None*")
+                assert out.shape == shape, \
+                    ":param:`out` must have shape %s." % shape
+                assert out.dtype == function.dat.data.dtype
         else:
-            out[:] = function_data[self._reordering_arr_fd2mm]
+            out = np.ndarray(shape, dtype=function_data.dtype)
 
+        # Reorder nodes
+        out[:] = np.moveaxis(function_data, 0, -1)[..., self.mm2fd_node_mapping]
         # Resample at the appropriate nodes
-        out_view = self.to_discr.groups[0].view(out)
+        out_view = self.discr.groups[self.group_nr].view(out)
         np.matmul(out_view, self._resampling_mat_fd2mm.T, out=out_view)
         return out
 
@@ -263,20 +327,20 @@ class FromFiredrakeConnection:
                       assert_fdrake_discontinuous=True,
                       continuity_tolerance=None):
         """
-        transport meshmode field from :attr:`to_discr` into an
+        transport meshmode field from :attr:`discr` into an
         appropriate firedrake function space.
 
-        :arg mm_field: A numpy array of shape *(nnodes,)* or *(dim, nnodes)*
+        :arg mm_field: A numpy array of shape *(nnodes,)* or *(..., nnodes)*
             representing a function on :attr:`to_distr`.
         :arg out: If *None* then ignored, otherwise a :mod:`firedrake`
             function of the right function space for the transported data
             to be stored in.
         :arg assert_fdrake_discontinuous: If *True*,
             disallows conversion to a continuous firedrake function space
-            (i.e. this function checks that ``self.from_fspace()`` is
+            (i.e. this function checks that ``self.firedrake_fspace()`` is
              discontinuous and raises a *ValueError* otherwise)
         :arg continuity_tolerance: If converting to a continuous firedrake
-            function space (i.e. if ``self.from_fspace()`` is continuous),
+            function space (i.e. if ``self.firedrake_fspace()`` is continuous),
             assert that at any two meshmode nodes corresponding to the
             same firedrake node (meshmode is a discontinuous space, so this
             situation will almost certainly happen), the function being transported
@@ -302,16 +366,18 @@ class FromFiredrakeConnection:
                 out.function_space().ufl_element().degree() \
                 == self._ufl_element.degree(), \
                 ":arg:`out` must live in a function space with the " \
-                "same family and degree as ``self.from_fspace()``"
+                "same family and degree as ``self.firedrake_fspace()``"
             assert out.function_space().mesh() is self._mesh_geometry, \
                 ":arg:`out` mesh must be the same mesh as used by " \
-                "``self.from_fspace().mesh()`` or *None*"
+                "``self.firedrake_fspace().mesh()`` or *None*"
         else:
             if len(mm_field.shape) == 1:
-                dim = None
+                vdim = None
+            elif len(mm_field.shape) == 2:
+                vdim = mm_field.shape[0]
             else:
-                dim = mm_field.shape[0]
-            out = Function(self.from_fspace(dim))
+                vdim = mm_field.shape[:-1]
+            out = Function(self.firedrake_fspace(vdim))
 
         # Handle 1-D case
         if len(out.dat.data.shape) == 1 and len(mm_field.shape) > 1:
@@ -320,17 +386,21 @@ class FromFiredrakeConnection:
         # resample from nodes on reordered view. Have to do this in
         # a bit of a roundabout way to be careful about duplicated
         # firedrake nodes.
-        by_cell_field_view = self.to_discr.groups[0].view(mm_field)
-        if len(out.dat.data.shape) == 1:
-            reordered_outdata = out.dat.data[self._reordering_arr_fd2mm]
-        else:
-            reordered_outdata = out.dat.data.T[:, self._reordering_arr_fd2mm]
-        by_cell_reordered_view = self.to_discr.groups[0].view(reordered_outdata)
+        el_group = self.discr.groups[self.group_nr]
+        by_cell_field_view = el_group.view(mm_field)
+
+        # Get firedrake data from out into meshmode ordering and view by cell
+        reordered_outdata = \
+            np.moveaxis(out.dat.data, 0, -1)[..., self.mm2fd_node_mapping]
+        by_cell_reordered_view = el_group.view(reordered_outdata)
+        # Resample this reordered data
         np.matmul(by_cell_field_view, self._resampling_mat_mm2fd.T,
                   out=by_cell_reordered_view)
-        out.dat.data[self._reordering_arr_fd2mm] = reordered_outdata.T
+        # Now store the resampled data back in the firedrake order
+        out.dat.data[self.mm2fd_node_mapping] = \
+            np.moveaxis(reordered_outdata, -1, 0)
 
-        # Continuity checks
+        # Continuity checks if requested
         if self._ufl_element.family() == 'Lagrange' \
                 and continuity_tolerance is not None:
             assert isinstance(continuity_tolerance, float)
@@ -346,12 +416,8 @@ class FromFiredrakeConnection:
                 # nodes may have been resampled to distinct nodes on different
                 # elements. reordered_outdata has undone that resampling.
                 for dup_mm_inode in duplicated_mm_nodes[1:]:
-                    if len(reordered_outdata.shape) > 1:
-                        dist = la.norm(reordered_outdata[:, mm_inode]
-                                       - reordered_outdata[:, dup_mm_inode])
-                    else:
-                        dist = la.norm(reordered_outdata[mm_inode]
-                                       - reordered_outdata[dup_mm_inode])
+                    dist = la.norm(reordered_outdata[..., mm_inode]
+                                   - reordered_outdata[..., dup_mm_inode])
                     if dist >= continuity_tolerance:
                         raise ValueError("Meshmode nodes %s and %s represent "
                                          "the same firedrake node %s, but "
@@ -365,19 +431,114 @@ class FromFiredrakeConnection:
 # }}}
 
 
+# {{{ Create connection from firedrake into meshmode
+
+class FromFiredrakeConnection(FiredrakeConnection):
+    """
+    A connection created from a :mod:`firedrake`
+    ``"CG"`` or ``"DG"`` function space which creates a corresponding
+    meshmode discretization and allows
+    transfer of functions to and from :mod:`firedrake`.
+    """
+    def __init__(self, cl_ctx, fdrake_fspace):
+        """
+        :arg cl_ctx: A :mod:`pyopencl` computing context
+        :arg fdrake_fspace: A :mod:`firedrake` ``"CG"`` or ``"DG"``
+            function space (of class :class:`WithGeometry`) built on
+            a mesh which is importable by :func:`import_firedrake_mesh`.
+        """
+        # Ensure fdrake_fspace is a function space with appropriate reference
+        # element.
+        from firedrake.functionspaceimpl import WithGeometry
+        if not isinstance(fdrake_fspace, WithGeometry):
+            raise TypeError(":arg:`fdrake_fspace` must be of firedrake type "
+                            ":class:`WithGeometry`, not `%s`."
+                            % type(fdrake_fspace))
+        ufl_elt = fdrake_fspace.ufl_element()
+
+        if ufl_elt.family() not in ('Lagrange', 'Discontinuous Lagrange'):
+            raise ValueError("the ``ufl_element().family()`` of "
+                             ":arg:`fdrake_fspace` must "
+                             "be ``'Lagrange'`` or "
+                             "``'Discontinuous Lagrange'``, not %s."
+                             % ufl_elt.family())
+
+        # Create to_discr
+        mm_mesh, orient = import_firedrake_mesh(fdrake_fspace.mesh())
+        factory = InterpolatoryQuadratureSimplexGroupFactory(ufl_elt.degree())
+        to_discr = Discretization(cl_ctx, mm_mesh, factory)
+
+        # get firedrake unit nodes and map onto meshmode reference element
+        group = to_discr.groups[0]
+        fd_ref_cell_to_mm = get_affine_reference_simplex_mapping(group.dim,
+                                                                 True)
+        fd_unit_nodes = get_finat_element_unit_nodes(fdrake_fspace.finat_element)
+        fd_unit_nodes = fd_ref_cell_to_mm(fd_unit_nodes)
+        # Flipping negative elements corresponds to reordering the nodes.
+        # We handle reordering by storing the permutation explicitly as
+        # a numpy array
+
+        # Get the reordering fd->mm.
+        #
+        # One should note there is something a bit more subtle going on
+        # in the continuous case. All meshmode discretizations use
+        # are discontinuous, so nodes are associated with elements(cells)
+        # not vertices. In a continuous firedrake space, some nodes
+        # are shared between multiple cells. In particular, while the
+        # below "reordering" is indeed a permutation if the firedrake space
+        # is discontinuous, if the firedrake space is continuous then
+        # some firedrake nodes correspond to nodes on multiple meshmode
+        # elements, i.e. those nodes appear multiple times
+        # in the "reordering" array
+        flip_mat = get_simplex_element_flip_matrix(ufl_elt.degree(),
+                                                   fd_unit_nodes)
+        fd_cell_node_list = fdrake_fspace.cell_node_list
+        _reorder_nodes(orient, fd_cell_node_list, flip_mat, unflip=False)
+        mm2fd_node_mapping = fd_cell_node_list.flatten()
+
+        super(FromFiredrakeConnection, self).__init__(to_discr,
+                                                      fdrake_fspace,
+                                                      mm2fd_node_mapping)
+        if fdrake_fspace.ufl_element().family() == 'Discontinuous Lagrange':
+            assert len(self._duplicate_nodes) == 0, \
+                "Somehow a firedrake node in a 'DG' space got duplicated..." \
+                "contact the developer."
+
+# }}}
+
+
 # {{{ Create connection to firedrake from meshmode
 
 
-# TODO : implement this (should be easy using export_mesh_to_firedrake
-#        and similar styles)
-class ToFiredrakeConnection:
-    def __init__(self, discr, group_nr=None):
+class ToFiredrakeConnection(FiredrakeConnection):
+    """
+    Create a connection from a firedrake discretization
+    into firedrake. Create a corresponding "DG" function
+    space and allow for conversion back and forth
+    by resampling at the nodes.
+    """
+    def __init__(self, discr, group_nr=None, comm=None):
         """
-        Create a connection from a firedrake discretization
-        into firedrake. Create a corresponding "DG" function
-        space and allow for conversion back and forth
-        by resampling at the nodes.
+        :param discr: A :class:`Discretization` to intialize the connection with
+        :param group_nr: The group number of the discretization to convert.
+            If *None* there must be only one group. The selected group
+            must be of type :class:`InterpolatoryQuadratureSimplexElementGroup`
+        :param comm: Communicator to build a dmplex object on for the created
+            firedrake mesh
         """
+        if group_nr is None:
+            assert len(discr.groups) == 1, ":arg:`group_nr` is *None*, but " \
+                    ":arg:`discr` has %s != 1 groups." % len(discr.groups)
+            group_nr = 0
+        el_group = discr.groups[group_nr]
+
+        from firedrake.functionspace import FunctionSpace
+        fd_mesh = export_mesh_to_firedrake(discr.mesh, group_nr, comm)
+        fspace = FunctionSpace(fd_mesh, 'DG', el_group.order)
+        super(ToFiredrakeConnection, self).__init__(discr,
+                                                    fspace,
+                                                    np.arange(el_group.nnodes),
+                                                    group_nr=group_nr)
 
 # }}}
 
diff --git a/test/test_firedrake_interop.py b/test/test_firedrake_interop.py
index 94174760..b3740f24 100644
--- a/test/test_firedrake_interop.py
+++ b/test/test_firedrake_interop.py
@@ -81,6 +81,9 @@ def fdrake_degree(request):
     return request.param
 
 
+# TODO : make some tests to verify boundary tagging
+
+
 # {{{ Basic conversion checks for the function space
 
 def test_discretization_consistency(ctx_factory, fdrake_mesh, fdrake_degree):
@@ -99,17 +102,17 @@ def test_discretization_consistency(ctx_factory, fdrake_mesh, fdrake_degree):
     fdrake_fspace = FunctionSpace(fdrake_mesh, 'DG', fdrake_degree)
     cl_ctx = ctx_factory()
     fdrake_connection = FromFiredrakeConnection(cl_ctx, fdrake_fspace)
-    to_discr = fdrake_connection.to_discr
-    meshmode_verts = to_discr.mesh.vertices
+    discr = fdrake_connection.discr
+    meshmode_verts = discr.mesh.vertices
 
     # Ensure the meshmode mesh has one group and make sure both
     # meshes agree on some basic properties
-    assert len(to_discr.mesh.groups) == 1
+    assert len(discr.mesh.groups) == 1
     fdrake_mesh_fspace = fdrake_mesh.coordinates.function_space()
     fdrake_mesh_order = fdrake_mesh_fspace.finat_element.degree
-    assert to_discr.mesh.groups[0].order == fdrake_mesh_order
-    assert to_discr.mesh.groups[0].nelements == fdrake_mesh.num_cells()
-    assert to_discr.mesh.nvertices == fdrake_mesh.num_vertices()
+    assert discr.mesh.groups[0].order == fdrake_mesh_order
+    assert discr.mesh.groups[0].nelements == fdrake_mesh.num_cells()
+    assert discr.mesh.nvertices == fdrake_mesh.num_vertices()
 
     # Ensure that the vertex sets are identical up to reordering
     # Nb: I got help on this from stack overflow:
@@ -121,10 +124,11 @@ def test_discretization_consistency(ctx_factory, fdrake_mesh, fdrake_degree):
     # Ensure the discretization and the firedrake function space agree on
     # some basic properties
     finat_elt = fdrake_fspace.finat_element
-    assert len(to_discr.groups) == 1
-    assert to_discr.groups[0].order == finat_elt.degree
-    assert to_discr.groups[0].nunit_nodes == finat_elt.space_dimension()
-    assert to_discr.nnodes == fdrake_fspace.node_count
+    assert len(discr.groups) == 1
+    assert discr.groups[0].order == finat_elt.degree
+    assert discr.groups[0].nunit_nodes == finat_elt.space_dimension()
+    assert discr.nnodes == fdrake_fspace.node_count
+
 
 # }}}
 
@@ -181,9 +185,9 @@ def test_function_transfer(ctx_factory,
     fdrake_connection = FromFiredrakeConnection(cl_ctx, fdrake_fspace)
     transported_f = fdrake_connection.from_firedrake(fdrake_f)
 
-    to_discr = fdrake_connection.to_discr
+    discr = fdrake_connection.discr
     with cl.CommandQueue(cl_ctx) as queue:
-        nodes = to_discr.nodes().get(queue=queue)
+        nodes = discr.nodes().get(queue=queue)
     meshmode_f = meshmode_f_eval(nodes)
 
     np.testing.assert_allclose(transported_f, meshmode_f, atol=CLOSE_ATOL)
@@ -200,7 +204,7 @@ def check_idempotency(fdrake_connection, fdrake_function):
     vdim = None
     if len(fdrake_function.dat.data.shape) > 1:
         vdim = fdrake_function.dat.data.shape[1]
-    fdrake_fspace = fdrake_connection.from_fspace(dim=vdim)
+    fdrake_fspace = fdrake_connection.firedrake_fspace(vdim=vdim)
 
     # Test for idempotency fd->mm->fd
     mm_field = fdrake_connection.from_firedrake(fdrake_function)
-- 
GitLab