From 5664f3d2beea8e9e2d1c40af97a1360ecfcc8d82 Mon Sep 17 00:00:00 2001
From: benSepanski <ben_sepanski@alumni.baylor.edu>
Date: Sun, 16 Aug 2020 08:46:40 -0500
Subject: [PATCH] First stab at changing From/ToFiredrakeConnection to factory
 functions

---
 doc/interop.rst                          |  14 +-
 examples/from_firedrake.py               |  24 +-
 examples/to_firedrake.py                 |   4 +-
 meshmode/interop/firedrake/__init__.py   |   8 +-
 meshmode/interop/firedrake/connection.py | 424 +++++++++++------------
 test/test_firedrake_interop.py           |  36 +-
 6 files changed, 242 insertions(+), 268 deletions(-)

diff --git a/doc/interop.rst b/doc/interop.rst
index 78caab70..c56665e5 100644
--- a/doc/interop.rst
+++ b/doc/interop.rst
@@ -12,15 +12,15 @@ Function Spaces/Discretizations
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Users wishing to interact with :mod:`meshmode` from :mod:`firedrake`
-will primarily interact with the
-:class:`~meshmode.interop.firedrake.connection.FromFiredrakeConnection` and
-:class:`~meshmode.interop.firedrake.connection.FromBoundaryFiredrakeConnection`
-classes, while users wishing
+will create a 
+:class:`~meshmode.interop.firedrake.connection.FiredrakeConnection`
+using :func:`~meshmode.interop.firedrake.connection.build_connection_from_firedrake`,
+while users wishing
 to interact with :mod:`firedrake` from :mod:`meshmode` will use
+will create a 
+:class:`~meshmode.interop.firedrake.connection.FiredrakeConnection`
+using :func:`~meshmode.interop.firedrake.connection.build_connection_to_firedrake`.
 the :class:`~meshmode.interop.firedrake.connection.ToFiredrakeConnection` class.
-All of these classes inherit from
-the :class:`~meshmode.interop.firedrake.connection.FiredrakeConnection`
-class, which provides the interface.
 It is not recommended to create a
 :class:`~meshmode.interop.firedrake.connection.FiredrakeConnection` directly.
 
diff --git a/examples/from_firedrake.py b/examples/from_firedrake.py
index 87bcbd9c..624d1dc4 100644
--- a/examples/from_firedrake.py
+++ b/examples/from_firedrake.py
@@ -36,8 +36,7 @@ def main():
     except ImportError:
         return 0
 
-    from meshmode.interop.firedrake import (
-        FromFiredrakeConnection, FromBoundaryFiredrakeConnection)
+    from meshmode.interop.firedrake import build_connection_from_firedrake
     from firedrake import (
         UnitSquareMesh, FunctionSpace, SpatialCoordinate, Function, cos
         )
@@ -54,22 +53,23 @@ def main():
     from meshmode.array_context import PyOpenCLArrayContext
     actx = PyOpenCLArrayContext(queue)
 
-    fd_connection = FromFiredrakeConnection(actx, fd_fspace)
-    fd_bdy_connection = FromBoundaryFiredrakeConnection(actx,
-                                                   fd_fspace,
-                                                   'on_boundary')
+    fd_connection = build_connection_from_firedrake(actx, fd_fspace)
+    fd_bdy_connection = \
+        build_connection_from_firedrake(actx,
+                                        fd_fspace,
+                                        restrict_to_boundary='on_boundary')
 
     # Plot the meshmode meshes that the connections connect to
     import matplotlib.pyplot as plt
     from meshmode.mesh.visualization import draw_2d_mesh
     fig, (ax1, ax2) = plt.subplots(1, 2)
-    ax1.set_title("FromFiredrakeConnection")
+    ax1.set_title("FiredrakeConnection")
     plt.sca(ax1)
     draw_2d_mesh(fd_connection.discr.mesh,
                  draw_vertex_numbers=False,
                  draw_element_numbers=False,
                  set_bounding_box=True)
-    ax2.set_title("FromBoundaryFiredrakeConnection")
+    ax2.set_title("FiredrakeConnection 'on_boundary'")
     plt.sca(ax2)
     draw_2d_mesh(fd_bdy_connection.discr.mesh,
                  draw_vertex_numbers=False,
@@ -77,7 +77,7 @@ def main():
                  set_bounding_box=True)
     plt.show()
 
-    # Plot fd_fntn using FromFiredrakeConnection
+    # Plot fd_fntn using unrestricted FiredrakeConnection
     from meshmode.discretization.visualization import make_visualizer
     discr = fd_connection.discr
     vis = make_visualizer(actx, discr, discr.groups[0].order+3)
@@ -85,17 +85,17 @@ def main():
 
     fig = plt.figure()
     ax1 = fig.add_subplot(1, 2, 1, projection='3d')
-    ax1.set_title("cos(x+y) in\nFromFiredrakeConnection")
+    ax1.set_title("cos(x+y) in\nFiredrakeConnection")
     vis.show_scalar_in_matplotlib_3d(field, do_show=False)
 
-    # Now repeat using FromBoundaryFiredrakeConnection
+    # Now repeat using FiredrakeConnection restricted to 'on_boundary'
     bdy_discr = fd_bdy_connection.discr
     bdy_vis = make_visualizer(actx, bdy_discr, bdy_discr.groups[0].order+3)
     bdy_field = fd_bdy_connection.from_firedrake(fd_fntn, actx=actx)
 
     ax2 = fig.add_subplot(1, 2, 2, projection='3d')
     plt.sca(ax2)
-    ax2.set_title("cos(x+y) in\nFromBoundaryFiredrakeConnection")
+    ax2.set_title("cos(x+y) in\nFiredrakeConnection 'on_boundary'")
     bdy_vis.show_scalar_in_matplotlib_3d(bdy_field, do_show=False)
 
     import matplotlib.cm as cm
diff --git a/examples/to_firedrake.py b/examples/to_firedrake.py
index 208a9054..fb1a6bef 100644
--- a/examples/to_firedrake.py
+++ b/examples/to_firedrake.py
@@ -82,8 +82,8 @@ def main():
 
     # {{{ Now send candidate_sol into firedrake and use it for boundary conditions
 
-    from meshmode.interop.firedrake import ToFiredrakeConnection
-    fd_connection = ToFiredrakeConnection(discr, group_nr=0)
+    from meshmode.interop.firedrake import build_connection_to_firedrake
+    fd_connection = build_connection_to_firedrake(discr, group_nr=0)
     # convert candidate_sol to firedrake
     fd_candidate_sol = fd_connection.from_meshmode(candidate_sol)
     # get the firedrake function space
diff --git a/meshmode/interop/firedrake/__init__.py b/meshmode/interop/firedrake/__init__.py
index f86d9de7..a4bc55a7 100644
--- a/meshmode/interop/firedrake/__init__.py
+++ b/meshmode/interop/firedrake/__init__.py
@@ -22,12 +22,12 @@ THE SOFTWARE.
 
 
 from meshmode.interop.firedrake.connection import (
-    FromBoundaryFiredrakeConnection, FromFiredrakeConnection,
-    ToFiredrakeConnection)
+    build_connection_from_firedrake, build_connection_to_firedrake,
+    FiredrakeConnection)
 from meshmode.interop.firedrake.mesh import (
     import_firedrake_mesh, export_mesh_to_firedrake)
 
-__all__ = ["FromBoundaryFiredrakeConnection", "FromFiredrakeConnection",
-           "ToFiredrakeConnection", "import_firedrake_mesh",
+__all__ = ["build_connection_from_firedrake", "build_connection_to_firedrake",
+           "FiredrakeConnection", "import_firedrake_mesh",
            "export_mesh_to_firedrake"
            ]
diff --git a/meshmode/interop/firedrake/connection.py b/meshmode/interop/firedrake/connection.py
index 057dd911..9b77453d 100644
--- a/meshmode/interop/firedrake/connection.py
+++ b/meshmode/interop/firedrake/connection.py
@@ -22,9 +22,8 @@ THE SOFTWARE.
 
 __doc__ = """
 .. autoclass:: FiredrakeConnection
-.. autoclass:: FromFiredrakeConnection
-.. autoclass:: FromBoundaryFiredrakeConnection
-.. autoclass:: ToFiredrakeConnection
+.. autofunction:: build_connection_to_firedrake
+.. autofunction:: build_connection_from_firedrake
 """
 
 import numpy as np
@@ -89,11 +88,12 @@ def _reorder_nodes(orient, nodes, flip_matrix, unflip=False):
 class FiredrakeConnection:
     """
     A connection between one group of
-    a meshmode discretization and a firedrake "CG" or "DG"
+    a meshmode discretization and a firedrake "DG"
     function space.
 
-    Users should instantiate this using a
-    :class:`FromFiredrakeConnection` or :class:`ToFiredrakeConnection`.
+    Users should instantiate this using
+    :func:`build_connection_to_firedrake`
+    or :func:`build_connection_from_firedrake`
 
     .. attribute:: discr
 
@@ -561,169 +561,143 @@ class FiredrakeConnection:
 # {{{ Create connection from firedrake into meshmode
 
 
-class FromFiredrakeConnection(FiredrakeConnection):
+def build_connection_from_firedrake(actx, fdrake_fspace, grp_factory=None,
+                                    restrict_to_boundary=None):
+
     """
-    A connection created from a :mod:`firedrake`
-    ``"CG"`` or ``"DG"`` function space which creates a corresponding
-    meshmode discretization and allows
+    Create a :class:`FiredrakeConnection` from a :mod:`firedrake`
+    ``"DG"`` function space by creates a corresponding
+    meshmode discretization and facilitating
     transfer of functions to and from :mod:`firedrake`.
 
-    .. automethod:: __init__
-    """
-    def __init__(self, actx, fdrake_fspace, grp_factory=None):
-        """
-        :arg actx: A :class:`~meshmode.array_context.ArrayContext`
-            used to instantiate :attr:`FiredrakeConnection.discr`.
-        :arg fdrake_fspace: A :mod:`firedrake` ``"CG"`` or ``"DG"``
-            function space (of class
-            :class:`~firedrake.functionspaceimpl.WithGeometry`) built on
-            a mesh which is importable by
-            :func:`~meshmode.interop.firedrake.mesh.import_firedrake_mesh`.
-        :arg grp_factory: (optional) If not *None*, should be
-            a :class:`~meshmode.discretization.poly_element.ElementGroupFactory`
-            whose group class is a subclass of
-            :class:`~meshmode.discretization.InterpolatoryElementGroupBase`.
-            If *None*, and :mod:`recursivenodes` can be imported,
-            a :class:`~meshmode.discretization.poly_element.\
+    :arg actx: A :class:`~meshmode.array_context.ArrayContext`
+        used to instantiate :attr:`FiredrakeConnection.discr`.
+    :arg fdrake_fspace: A :mod:`firedrake` ``"DG"``
+        function space (of class
+        :class:`~firedrake.functionspaceimpl.WithGeometry`) built on
+        a mesh which is importable by
+        :func:`~meshmode.interop.firedrake.mesh.import_firedrake_mesh`.
+    :arg grp_factory: (optional) If not *None*, should be
+        a :class:`~meshmode.discretization.poly_element.ElementGroupFactory`
+        whose group class is a subclass of
+        :class:`~meshmode.discretization.InterpolatoryElementGroupBase`.
+        If *None*, and :mod:`recursivenodes` can be imported,
+        a :class:`~meshmode.discretization.poly_element.\
 PolynomialRecursiveNodesGroupFactory` with ``'lgl'`` nodes is used.
-            Note that :mod:`recursivenodes` may not be importable
-            as it uses :func:`math.comb`, which is new in Python 3.8.
-            In the case that :mod:`recursivenodes` cannot be successfully
-            imported, a :class:`~meshmode.discretization.poly_element.\
+        Note that :mod:`recursivenodes` may not be importable
+        as it uses :func:`math.comb`, which is new in Python 3.8.
+        In the case that :mod:`recursivenodes` cannot be successfully
+        imported, a :class:`~meshmode.discretization.poly_element.\
 PolynomialWarpAndBlendGroupFactory` is used.
-        """
-        # Ensure fdrake_fspace is a function space with appropriate reference
-        # element.
-        from firedrake.functionspaceimpl import WithGeometry
-        if not isinstance(fdrake_fspace, WithGeometry):
-            raise TypeError("'fdrake_fspace' must be of firedrake type "
-                            "WithGeometry, not '%s'."
-                            % type(fdrake_fspace))
-        ufl_elt = fdrake_fspace.ufl_element()
-
-        if ufl_elt.family() != 'Discontinuous Lagrange':
-            raise ValueError("the 'fdrake_fspace.ufl_element().family()' of "
-                             "must be be "
-                             "'Discontinuous Lagrange', not '%s'."
-                             % ufl_elt.family())
-        # Make sure grp_factory is the right type if provided, and
-        # uses an interpolatory class.
-        if grp_factory is not None:
-            if not isinstance(grp_factory, ElementGroupFactory):
-                raise TypeError("'grp_factory' must inherit from "
-                                "meshmode.discretization.ElementGroupFactory,"
-                                "but is instead of type "
-                                "'%s'." % type(grp_factory))
-            if not issubclass(grp_factory.group_class,
-                              InterpolatoryElementGroupBase):
-                raise TypeError("'grp_factory.group_class' must inherit from"
-                                "meshmode.discretization."
-                                "InterpolatoryElementGroupBase, but"
-                                " is instead of type '%s'"
-                                % type(grp_factory.group_class))
-        # If not provided, make one
-        else:
-            degree = ufl_elt.degree()
-            try:
-                # recursivenodes is only importable in Python 3.8 since
-                # it uses :func:`math.comb`, so need to check if it can
-                # be imported
-                import recursivenodes  # noqa : F401
-                family = 'lgl'  # L-G-Legendre
-                grp_factory = PolynomialRecursiveNodesGroupFactory(degree, family)
-            except ImportError:
-                # If cannot be imported, uses warp-and-blend nodes
-                grp_factory = PolynomialWarpAndBlendGroupFactory(degree)
-
-        # In case this class is really a FromBoundaryFiredrakeConnection,
-        # get *cells_to_use*
-        cells_to_use = self._get_cells_to_use(fdrake_fspace.mesh())
-        # Create to_discr
-        mm_mesh, orient = import_firedrake_mesh(fdrake_fspace.mesh(),
-                                                cells_to_use=cells_to_use)
-        to_discr = Discretization(actx, mm_mesh, grp_factory)
-
-        # get firedrake unit nodes and map onto meshmode reference element
-        group = to_discr.groups[0]
-        fd_ref_cell_to_mm = get_affine_reference_simplex_mapping(group.dim,
-                                                                 True)
-        fd_unit_nodes = get_finat_element_unit_nodes(fdrake_fspace.finat_element)
-        fd_unit_nodes = fd_ref_cell_to_mm(fd_unit_nodes)
-        # Flipping negative elements corresponds to reordering the nodes.
-        # We handle reordering by storing the permutation explicitly as
-        # a numpy array
-
-        # Get the reordering fd->mm.
-        flip_mat = get_simplex_element_flip_matrix(ufl_elt.degree(),
-                                                   fd_unit_nodes)
-        fd_cell_node_list = fdrake_fspace.cell_node_list
-        if cells_to_use is not None:
-            fd_cell_node_list = fd_cell_node_list[cells_to_use]
-        # flip fd_cell_node_list
-        flipped_cell_node_list = _reorder_nodes(orient,
-                                                fd_cell_node_list,
-                                                flip_mat,
-                                                unflip=False)
-
-        assert np.size(np.unique(flipped_cell_node_list)) == \
-            np.size(flipped_cell_node_list), \
-            "A firedrake node in a 'DG' space got duplicated"
-        super(FromFiredrakeConnection, self).__init__(to_discr,
-                                                      fdrake_fspace,
-                                                      flipped_cell_node_list)
-
-    def _get_cells_to_use(self, mesh):
-        """
-        For compatibility with :class:`FromFiredrakeBdyConnection`
-        """
-        return None
-
-
-class FromBoundaryFiredrakeConnection(FromFiredrakeConnection):
-    """
-    A connection created from a :mod:`firedrake`
-    ``"CG"`` or ``"DG"`` function space which creates a
-    meshmode discretization corresponding to all cells with at
-    least one vertex on the given boundary and allows
-    transfer of functions to and from :mod:`firedrake`.
-
-    Use the same bdy_id as one would for a
-    :class:`firedrake.bcs.DirichletBC` instance.
-    ``"on_boundary"`` corresponds to the entire boundary.
-
-    .. attribute:: bdy_id
-
-        the boundary id of the boundary being connecting from
-
-    .. automethod:: __init__
+    :arg restrict_to_boundary: (optional)
+        If not *None*, then must be a valid boundary marker for
+        ``fdrake_fspace.mesh()``. In this case, creates a
+        :class:`~meshmode.discretization.Discretization` on a submesh
+        of ``fdrake_fspace.mesh()`` created from the cells with at least
+        one vertex on a facet marked with the marker
+        *restrict_to_boundary*.
     """
-    def __init__(self, actx, fdrake_fspace, bdy_id, grp_factory=None):
-        """
-        :arg bdy_id: A boundary marker of *fdrake_fspace.mesh()* as accepted by
-            the *boundary_nodes* method of a
-            :class:`firedrake.functionspaceimpl.WithGeometry`.
-
-        Other arguments are as in
-        :class:`~meshmode.interop.firedrake.connection.FromFiredrakeConnection`.
-        """
-        self.bdy_id = bdy_id
-        super(FromBoundaryFiredrakeConnection, self).__init__(
-            actx, fdrake_fspace, grp_factory=grp_factory)
-
-    def _get_cells_to_use(self, mesh):
-        """
-        Returns an array of the cell ids with >= 1 vertex on the
-        given bdy_id
-        """
-        cfspace = mesh.coordinates.function_space()
+    # Ensure fdrake_fspace is a function space with appropriate reference
+    # element.
+    from firedrake.functionspaceimpl import WithGeometry
+    if not isinstance(fdrake_fspace, WithGeometry):
+        raise TypeError("'fdrake_fspace' must be of firedrake type "
+                        "WithGeometry, not '%s'."
+                        % type(fdrake_fspace))
+    ufl_elt = fdrake_fspace.ufl_element()
+
+    if ufl_elt.family() != 'Discontinuous Lagrange':
+        raise ValueError("the 'fdrake_fspace.ufl_element().family()' of "
+                         "must be be "
+                         "'Discontinuous Lagrange', not '%s'."
+                         % ufl_elt.family())
+    # Make sure grp_factory is the right type if provided, and
+    # uses an interpolatory class.
+    if grp_factory is not None:
+        if not isinstance(grp_factory, ElementGroupFactory):
+            raise TypeError("'grp_factory' must inherit from "
+                            "meshmode.discretization.ElementGroupFactory,"
+                            "but is instead of type "
+                            "'%s'." % type(grp_factory))
+        if not issubclass(grp_factory.group_class,
+                          InterpolatoryElementGroupBase):
+            raise TypeError("'grp_factory.group_class' must inherit from"
+                            "meshmode.discretization."
+                            "InterpolatoryElementGroupBase, but"
+                            " is instead of type '%s'"
+                            % type(grp_factory.group_class))
+    # If not provided, make one
+    else:
+        degree = ufl_elt.degree()
+        try:
+            # recursivenodes is only importable in Python 3.8 since
+            # it uses :func:`math.comb`, so need to check if it can
+            # be imported
+            import recursivenodes  # noqa : F401
+            family = 'lgl'  # L-G-Legendre
+            grp_factory = PolynomialRecursiveNodesGroupFactory(degree, family)
+        except ImportError:
+            # If cannot be imported, uses warp-and-blend nodes
+            grp_factory = PolynomialWarpAndBlendGroupFactory(degree)
+    if restrict_to_boundary is not None:
+        uniq_markers = fdrake_fspace.mesh().exterior_facets.unique_markers
+        allowable_bdy_ids = list(uniq_markers) + ["on_boundary"]
+        if restrict_to_boundary not in allowable_bdy_ids:
+            raise ValueError("'restrict_to_boundary' must be one of"
+                            " the following allowable boundary ids: "
+                            f"{allowable_bdy_ids}, not "
+                            f"'{restrict_to_boundary}'")
+
+    # If only converting a portion of the mesh near the boundary, get
+    # *cells_to_use* as described in
+    # :func:`meshmode.interop.firedrake.mesh.import_firedrake_mesh`
+    cells_to_use = None
+    if restrict_to_boundary is not None:
+        cfspace = fdrake_fspace.mesh().coordinates.function_space()
         cell_node_list = cfspace.cell_node_list
 
-        boundary_nodes = cfspace.boundary_nodes(self.bdy_id, 'topological')
+        boundary_nodes = cfspace.boundary_nodes(restrict_to_boundary,
+                                                'topological')
         # Reduce along each cell: Is a vertex of the cell in boundary nodes?
         cell_is_near_bdy = np.any(np.isin(cell_node_list, boundary_nodes), axis=1)
 
         from pyop2.datatypes import IntType
-        return np.nonzero(cell_is_near_bdy)[0].astype(IntType)
+        cells_to_use = np.nonzero(cell_is_near_bdy)[0].astype(IntType)
+
+    # Create to_discr
+    mm_mesh, orient = import_firedrake_mesh(fdrake_fspace.mesh(),
+                                            cells_to_use=cells_to_use)
+    to_discr = Discretization(actx, mm_mesh, grp_factory)
+
+    # get firedrake unit nodes and map onto meshmode reference element
+    group = to_discr.groups[0]
+    fd_ref_cell_to_mm = get_affine_reference_simplex_mapping(group.dim,
+                                                             True)
+    fd_unit_nodes = get_finat_element_unit_nodes(fdrake_fspace.finat_element)
+    fd_unit_nodes = fd_ref_cell_to_mm(fd_unit_nodes)
+    # Flipping negative elements corresponds to reordering the nodes.
+    # We handle reordering by storing the permutation explicitly as
+    # a numpy array
+
+    # Get the reordering fd->mm.
+    flip_mat = get_simplex_element_flip_matrix(ufl_elt.degree(),
+                                               fd_unit_nodes)
+    fd_cell_node_list = fdrake_fspace.cell_node_list
+    if cells_to_use is not None:
+        fd_cell_node_list = fd_cell_node_list[cells_to_use]
+    # flip fd_cell_node_list
+    flipped_cell_node_list = _reorder_nodes(orient,
+                                            fd_cell_node_list,
+                                            flip_mat,
+                                            unflip=False)
+
+    assert np.size(np.unique(flipped_cell_node_list)) == \
+        np.size(flipped_cell_node_list), \
+        "A firedrake node in a 'DG' space got duplicated"
+
+    return FiredrakeConnection(to_discr,
+                               fdrake_fspace,
+                               flipped_cell_node_list)
 
 # }}}
 
@@ -731,84 +705,80 @@ class FromBoundaryFiredrakeConnection(FromFiredrakeConnection):
 # {{{ Create connection to firedrake from meshmode
 
 
-class ToFiredrakeConnection(FiredrakeConnection):
+def build_connection_to_firedrake(discr, group_nr=None, comm=None):
     """
     Create a connection from a meshmode discretization
     into firedrake. Create a corresponding "DG" function
     space and allow for conversion back and forth
     by resampling at the nodes.
 
-    .. automethod:: __init__
-    """
-    def __init__(self, discr, group_nr=None, comm=None):
-        """
-        :param discr: A :class:`~meshmode.discretization.Discretization`
-            to intialize the connection with
-        :param group_nr: The group number of the discretization to convert.
-            If *None* there must be only one group. The selected group
-            must be of type
-            :class:`~meshmode.discretization.poly_element.\
+    :param discr: A :class:`~meshmode.discretization.Discretization`
+        to intialize the connection with
+    :param group_nr: The group number of the discretization to convert.
+        If *None* there must be only one group. The selected group
+        must be of type
+        :class:`~meshmode.discretization.poly_element.\
 InterpolatoryQuadratureSimplexElementGroup`.
 
-        :param comm: Communicator to build a dmplex object on for the created
-            firedrake mesh
-        """
-        if group_nr is None:
-            if len(discr.groups) != 1:
-                raise ValueError("'group_nr' is *None*, but 'discr' has '%s' "
-                                 "!= 1 groups." % len(discr.groups))
-            group_nr = 0
-        el_group = discr.groups[group_nr]
-
-        from firedrake.functionspace import FunctionSpace
-        fd_mesh, fd_cell_order, perm2cells = \
-            export_mesh_to_firedrake(discr.mesh, group_nr, comm)
-        fspace = FunctionSpace(fd_mesh, 'DG', el_group.order)
-        # get firedrake unit nodes and map onto meshmode reference element
-        dim = fspace.mesh().topological_dimension()
-        fd_ref_cell_to_mm = get_affine_reference_simplex_mapping(dim, True)
-        fd_unit_nodes = get_finat_element_unit_nodes(fspace.finat_element)
-        fd_unit_nodes = fd_ref_cell_to_mm(fd_unit_nodes)
-
-        # **_cell_node holds the node nrs in shape *(ncells, nunit_nodes)*
-        fd_cell_node = fspace.cell_node_list
-
-        # To get the meshmode to firedrake node assocation, we need to handle
-        # local vertex reordering and cell reordering.
-        from pyop2.datatypes import IntType
-        mm2fd_node_mapping = np.ndarray((el_group.nelements, el_group.nunit_dofs),
-                                        dtype=IntType)
-        for perm, cells in perm2cells.items():
-            # reordering_arr[i] should be the fd node corresponding to meshmode
-            # node i
-            #
-            # The jth meshmode cell corresponds to the fd_cell_order[j]th
-            # firedrake cell. If *nodeperm* is the permutation of local nodes
-            # applied to the *j*th meshmode cell, the firedrake node
-            # fd_cell_node[fd_cell_order[j]][k] corresponds to the
-            # mm_cell_node[j, nodeperm[k]]th meshmode node.
-            #
-            # Note that the permutation on the unit nodes may not be the
-            # same as the permutation on the barycentric coordinates (*perm*).
-            # Importantly, the permutation is derived from getting a flip
-            # matrix from the Firedrake unit nodes, not necessarily the meshmode
-            # unit nodes
-            #
-            flip_mat = get_simplex_element_flip_matrix(el_group.order,
-                                                       fd_unit_nodes,
-                                                       np.argsort(perm))
-            flip_mat = np.rint(flip_mat).astype(IntType)
-            fd_permuted_cell_node = np.matmul(fd_cell_node[fd_cell_order[cells]],
-                                              flip_mat.T)
-            mm2fd_node_mapping[cells] = fd_permuted_cell_node
-
-        assert np.size(np.unique(mm2fd_node_mapping)) == \
-            np.size(mm2fd_node_mapping), \
-            "A firedrake node in a 'DG' space got duplicated"
-        super(ToFiredrakeConnection, self).__init__(discr,
-                                                    fspace,
-                                                    mm2fd_node_mapping,
-                                                    group_nr=group_nr)
+    :param comm: Communicator to build a dmplex object on for the created
+        firedrake mesh
+    """
+    if group_nr is None:
+        if len(discr.groups) != 1:
+            raise ValueError("'group_nr' is *None*, but 'discr' has '%s' "
+                             "!= 1 groups." % len(discr.groups))
+        group_nr = 0
+    el_group = discr.groups[group_nr]
+
+    from firedrake.functionspace import FunctionSpace
+    fd_mesh, fd_cell_order, perm2cells = \
+        export_mesh_to_firedrake(discr.mesh, group_nr, comm)
+    fspace = FunctionSpace(fd_mesh, 'DG', el_group.order)
+    # get firedrake unit nodes and map onto meshmode reference element
+    dim = fspace.mesh().topological_dimension()
+    fd_ref_cell_to_mm = get_affine_reference_simplex_mapping(dim, True)
+    fd_unit_nodes = get_finat_element_unit_nodes(fspace.finat_element)
+    fd_unit_nodes = fd_ref_cell_to_mm(fd_unit_nodes)
+
+    # **_cell_node holds the node nrs in shape *(ncells, nunit_nodes)*
+    fd_cell_node = fspace.cell_node_list
+
+    # To get the meshmode to firedrake node assocation, we need to handle
+    # local vertex reordering and cell reordering.
+    from pyop2.datatypes import IntType
+    mm2fd_node_mapping = np.ndarray((el_group.nelements, el_group.nunit_dofs),
+                                    dtype=IntType)
+    for perm, cells in perm2cells.items():
+        # reordering_arr[i] should be the fd node corresponding to meshmode
+        # node i
+        #
+        # The jth meshmode cell corresponds to the fd_cell_order[j]th
+        # firedrake cell. If *nodeperm* is the permutation of local nodes
+        # applied to the *j*th meshmode cell, the firedrake node
+        # fd_cell_node[fd_cell_order[j]][k] corresponds to the
+        # mm_cell_node[j, nodeperm[k]]th meshmode node.
+        #
+        # Note that the permutation on the unit nodes may not be the
+        # same as the permutation on the barycentric coordinates (*perm*).
+        # Importantly, the permutation is derived from getting a flip
+        # matrix from the Firedrake unit nodes, not necessarily the meshmode
+        # unit nodes
+        #
+        flip_mat = get_simplex_element_flip_matrix(el_group.order,
+                                                   fd_unit_nodes,
+                                                   np.argsort(perm))
+        flip_mat = np.rint(flip_mat).astype(IntType)
+        fd_permuted_cell_node = np.matmul(fd_cell_node[fd_cell_order[cells]],
+                                          flip_mat.T)
+        mm2fd_node_mapping[cells] = fd_permuted_cell_node
+
+    assert np.size(np.unique(mm2fd_node_mapping)) == \
+        np.size(mm2fd_node_mapping), \
+        "A firedrake node in a 'DG' space got duplicated"
+    return FiredrakeConnection(discr,
+                               fspace,
+                               mm2fd_node_mapping,
+                               group_nr=group_nr)
 
 # }}}
 
diff --git a/test/test_firedrake_interop.py b/test/test_firedrake_interop.py
index 338bd812..06e6c661 100644
--- a/test/test_firedrake_interop.py
+++ b/test/test_firedrake_interop.py
@@ -39,8 +39,8 @@ from meshmode.dof_array import DOFArray
 from meshmode.mesh import BTAG_ALL, BTAG_REALLY_ALL, check_bc_coverage
 
 from meshmode.interop.firedrake import (
-    FromFiredrakeConnection, FromBoundaryFiredrakeConnection,
-    ToFiredrakeConnection, import_firedrake_mesh)
+    build_connection_from_firedrake, build_connection_to_firedrake,
+    import_firedrake_mesh)
 
 import pytest
 
@@ -168,7 +168,7 @@ def check_consistency(fdrake_fspace, discr, group_nr=0):
 
 def test_from_fd_consistency(ctx_factory, fdrake_mesh, fspace_degree):
     """
-    Check basic consistency with a FromFiredrakeConnection
+    Check basic consistency with a FiredrakeConnection built from firedrake
     """
     # make discretization from firedrake
     fdrake_fspace = FunctionSpace(fdrake_mesh, 'DG', fspace_degree)
@@ -177,7 +177,7 @@ def test_from_fd_consistency(ctx_factory, fdrake_mesh, fspace_degree):
     queue = cl.CommandQueue(cl_ctx)
     actx = PyOpenCLArrayContext(queue)
 
-    fdrake_connection = FromFiredrakeConnection(actx, fdrake_fspace)
+    fdrake_connection = build_connection_from_firedrake(actx, fdrake_fspace)
     discr = fdrake_connection.discr
     # Check consistency
     check_consistency(fdrake_fspace, discr)
@@ -192,7 +192,7 @@ def test_to_fd_consistency(ctx_factory, mm_mesh, fspace_degree):
 
     factory = InterpolatoryQuadratureSimplexGroupFactory(fspace_degree)
     discr = Discretization(actx, mm_mesh, factory)
-    fdrake_connection = ToFiredrakeConnection(discr)
+    fdrake_connection = build_connection_to_firedrake(discr)
     fdrake_fspace = fdrake_connection.firedrake_fspace()
     # Check consistency
     check_consistency(fdrake_fspace, discr)
@@ -353,7 +353,9 @@ def test_bdy_tags(square_or_cube_mesh, bdy_ids, coord_indices, coord_values,
 # }}}
 
 
-# TODO : Add test for ToFiredrakeConnection where group_nr != 0
+# TODO : Add test for FiredrakeConnection built from meshmode
+#        where group_nr != 0
+
 # {{{  Double check functions are being transported correctly
 
 
@@ -435,11 +437,12 @@ def test_from_fd_transfer(ctx_factory, fspace_degree,
         # make function space and build connection
         fdrake_fspace = FunctionSpace(fdrake_mesh, 'DG', fspace_degree)
         if only_convert_bdy:
-            fdrake_connection = FromBoundaryFiredrakeConnection(actx,
-                                                                fdrake_fspace,
-                                                                'on_boundary')
+            fdrake_connection = \
+                build_connection_from_firedrake(actx,
+                                                fdrake_fspace,
+                                                restrict_to_boundary='on_boundary')
         else:
-            fdrake_connection = FromFiredrakeConnection(actx, fdrake_fspace)
+            fdrake_connection = build_connection_from_firedrake(actx, fdrake_fspace)
         # get this for making functions in firedrake
         spatial_coord = SpatialCoordinate(fdrake_mesh)
 
@@ -525,7 +528,7 @@ def test_to_fd_transfer(ctx_factory, fspace_degree, mesh_name, mesh_pars, dim):
         factory = InterpolatoryQuadratureSimplexGroupFactory(fspace_degree)
         discr = Discretization(actx, mm_mesh, factory)
 
-        fdrake_connection = ToFiredrakeConnection(discr)
+        fdrake_connection = build_connection_to_firedrake(discr)
         fdrake_fspace = fdrake_connection.firedrake_fspace()
         spatial_coord = SpatialCoordinate(fdrake_fspace.mesh())
 
@@ -597,13 +600,14 @@ def test_from_fd_idempotency(ctx_factory,
     #
     # Otherwise, just continue as normal
     if only_convert_bdy:
-        fdrake_connection = FromBoundaryFiredrakeConnection(actx,
-                                                            fdrake_fspace,
-                                                            'on_boundary')
+        fdrake_connection = \
+            build_connection_from_firedrake(actx,
+                                            fdrake_fspace,
+                                            restrict_to_boundary='on_boundary')
         temp = fdrake_connection.from_firedrake(fdrake_unique, actx=actx)
         fdrake_unique = fdrake_connection.from_meshmode(temp)
     else:
-        fdrake_connection = FromFiredrakeConnection(actx, fdrake_fspace)
+        fdrake_connection = build_connection_from_firedrake(actx, fdrake_fspace)
 
     # Test for idempotency fd->mm->fd
     mm_field = fdrake_connection.from_firedrake(fdrake_unique, actx=actx)
@@ -643,7 +647,7 @@ def test_to_fd_idempotency(ctx_factory, mm_mesh, fspace_degree):
     # Make a function space and a function with unique values at each node
     factory = InterpolatoryQuadratureSimplexGroupFactory(fspace_degree)
     discr = Discretization(actx, mm_mesh, factory)
-    fdrake_connection = ToFiredrakeConnection(discr)
+    fdrake_connection = build_connection_to_firedrake(discr)
     fdrake_mesh = fdrake_connection.firedrake_fspace().mesh()
     dtype = fdrake_mesh.coordinates.dat.data.dtype
 
-- 
GitLab