From 8831b0d3357368034d19e9fcfe2e717d6beade0e Mon Sep 17 00:00:00 2001
From: benSepanski <ben_sepanski@alumni.baylor.edu>
Date: Thu, 25 Jun 2020 16:24:53 -0500
Subject: [PATCH] Redid mesh to be functional and use dmplex/cell_closure
 properly

---
 meshmode/interop/firedrake/mesh.py           | 462 +++++++++++++++----
 meshmode/interop/firedrake/reference_cell.py |   8 +-
 meshmode/mesh/processing.py                  |  61 ++-
 3 files changed, 422 insertions(+), 109 deletions(-)

diff --git a/meshmode/interop/firedrake/mesh.py b/meshmode/interop/firedrake/mesh.py
index d067fb27..7ca5ea8f 100644
--- a/meshmode/interop/firedrake/mesh.py
+++ b/meshmode/interop/firedrake/mesh.py
@@ -22,116 +22,107 @@ THE SOFTWARE.
 
 __doc__ = """
 .. autofunction:: get_firedrake_nodal_adjacency_group
-.. autofunction:: get_firedrake_boundary_tags
+.. autofunction:: get_firedrake_vertex_indices
 """
 
 from warnings import warn  # noqa
 import numpy as np
+import six
 
 
 # {{{ functions to extract information from Mesh Topology
 
-def get_firedrake_nodal_adjacency_group(fdrake_mesh, cells_to_use=None):
+
+def _get_firedrake_nodal_info(fdrake_mesh_topology):
+    # FIXME: do docs
     """
-    Create a nodal adjacency object
-    representing the nodal adjacency of :param:`fdrake_mesh`
-
-    :param fdrake_mesh: A firedrake mesh (:class:`MeshTopology`
-        or :class:`MeshGeometry`)
-    :param cells_to_use: Either
-
-            * *None*, in which case this argument is is ignored
-            * A numpy array of firedrake cell indices, in which case
-              any cell with indices not in the array :param:`cells_to_use`
-              is ignored.
-              This induces a new order on the cells.
-              The *i*th element in the returned :class:`NodalAdjacency`
-              object corresponds to the ``cells_to_use[i]``th cell
-              in the firedrake mesh.
-
-        This feature has been used when only part of the mesh
-        needs to be converted, since firedrake has no concept
-        of a "sub-mesh".
-
-    :return: A :class:`meshmode.mesh.NodalAdjacency` instance
-        representing the nodal adjacency of :param:`fdrake_mesh`
+    Get nodal adjacency, and vertex indices from a firedrake mesh topology
     """
-    mesh_topology = fdrake_mesh.topology
+    top = fdrake_mesh_topology.topology
+
+    # If you don't understand dmplex, look at the PETSc reference
+    # here: https://cse.buffalo.edu/~knepley/classes/caam519/CSBook.pdf
+    # used to get topology info
     # TODO... not sure how to get around the private access
-    plex = mesh_topology._plex
+    plex = top._plex
 
-    # dmplex cell Start/end and vertex Start/end.
+    # Get range of dmplex ids for cells, facets, and vertices
     c_start, c_end = plex.getHeightStratum(0)
+    f_start, f_end = plex.getHeightStratum(1)
     v_start, v_end = plex.getDepthStratum(0)
 
-    # TODO... not sure how to get around the private access
-    # This maps the dmplex index of a cell to its firedrake index
-    to_fd_id = np.vectorize(mesh_topology._cell_numbering.getOffset)(
-        np.arange(c_start, c_end, dtype=np.int32))
-
-    element_to_neighbors = {}
-    verts_checked = set()  # dmplex ids of vertex checked
-
-    # If using all cells, loop over them all
-    if cells_to_use is None:
-        range_ = range(c_start, c_end)
-    # Otherwise, just the ones you're using
-    else:
-        assert isinstance(cells_to_use, np.ndarray)
-        assert np.size(cells_to_use) == np.size(np.unique(cells_to_use)), \
-            "cells_to_use must have unique values"
-        assert len(np.shape(cells_to_use)) == 1 and len(cells_to_use) > 0
-        isin = np.isin(to_fd_id, cells_to_use)
-        range_ = np.arange(c_start, c_end, dtype=np.int32)[isin]
-
-    # For each cell
-    for cell_id in range_:
-        # For each vertex touching the cell (that haven't already seen)
-        for vert_id in plex.getTransitiveClosure(cell_id)[0]:
-            if v_start <= vert_id < v_end and vert_id not in verts_checked:
-                verts_checked.add(vert_id)
-                cells = []
-                # Record all cells touching that vertex
-                support = plex.getTransitiveClosure(vert_id, useCone=False)[0]
-                for other_cell_id in support:
-                    if c_start <= other_cell_id < c_end:
-                        cells.append(to_fd_id[other_cell_id - c_start])
-
-                # If only using some cells, clean out extraneous ones
-                # and relabel them to new id
-                cells = set(cells)
-                if cells_to_use is not None:
-                    cells = set([cells_to_use[fd_ndx] for fd_ndx in cells
-                                 if fd_ndx in cells_to_use])
-
-                # mark cells as neighbors
-                for cell_one in cells:
-                    element_to_neighbors.setdefault(cell_one, set())
-                    element_to_neighbors[cell_one] |= cells
-
-    # Count the number of cells
-    if cells_to_use is None:
-        nelements = mesh_topology.num_cells()
-    else:
-        nelements = cells_to_use.shape[0]
-
-    # Create neighbors_starts and neighbors
+    # TODO... not sure how to get around the private accesses
+    # Maps dmplex cell id -> firedrake cell index
+    def cell_id_dmp_to_fd(ndx):
+        return top._cell_numbering.getOffset(ndx)
+
+    # Maps dmplex vert id -> firedrake vert index
+    def vert_id_dmp_to_fd(ndx):
+        return top._vertex_numbering.getOffset(ndx)
+
+    # FIXME : Is this the right integer type?
+    # We will fill in the values as we go
+    vertex_indices = -np.ones((top.num_cells(), top.ufl_cell().num_vertices()),
+                              dtype=np.int32)
+    # This will map fd cell ndx -> list of fd cell indices which share a vertex
+    cell_to_nodal_neighbors = {}
+    # This will map dmplex facet id -> list of adjacent
+    #                                  (fd cell ndx, firedrake local fac num)
+    facet_to_cells = {}
+    # This will map dmplex vert id -> list of fd cell
+    #                                 indices which touch this vertex,
+    # Primarily used to construct cell_to_nodal_neighbors
+    vert_to_cells = {}
+
+    # Loop through each cell (cell closure is all the dmplex ids for any
+    # verts, faces, etc. associated with the cell)
+    for fd_cell_ndx, closure_dmp_ids in enumerate(top.cell_closure):
+        # Store the vertex indices
+        dmp_verts = closure_dmp_ids[np.logical_and(v_start <= closure_dmp_ids,
+                                                   closure_dmp_ids < v_end)]
+        fd_verts = np.array([vert_id_dmp_to_fd(dmp_vert)
+                             for dmp_vert in dmp_verts])
+        vertex_indices[fd_cell_ndx][:] = fd_verts[:]
+
+        # Record this cell as touching the facet and remember its local
+        # facet number (the order it appears)
+        dmp_fac_ids = closure_dmp_ids[np.logical_and(f_start <= closure_dmp_ids,
+                                                     closure_dmp_ids < f_end)]
+        for loc_fac_nr, dmp_fac_id in enumerate(dmp_fac_ids):
+            # make sure there is a list to append to and append
+            facet_to_cells.setdefault(dmp_fac_id, [])
+            facet_to_cells[dmp_fac_id].append((fd_cell_ndx, loc_fac_nr))
+
+        # Record this vertex as touching the cell, and mark this cell
+        # as nodally adjacent (in cell_to_nodal_neighbors) to any
+        # cells already documented as touching this cell
+        cell_to_nodal_neighbors[fd_cell_ndx] = []
+        for dmp_vert_id in dmp_verts:
+            vert_to_cells.setdefault(dmp_vert_id, [])
+            for other_cell_ndx in vert_to_cells[dmp_vert_id]:
+                cell_to_nodal_neighbors[fd_cell_ndx].append(other_cell_ndx)
+                cell_to_nodal_neighbors[other_cell_ndx].append(fd_cell_ndx)
+            vert_to_cells[dmp_vert_id].append(fd_cell_ndx)
+
+    # Next go ahead and compute nodal adjacency by creating
+    # neighbors and neighbor_starts as specified by :class:`NodalAdjacency`
     neighbors = []
     # FIXME : Is this the right integer type to choose?
-    neighbors_starts = np.zeros(nelements + 1, dtype=np.int32)
-    for iel in range(len(element_to_neighbors)):
-        elt_neighbors = element_to_neighbors[iel]
-        neighbors += list(elt_neighbors)
+    neighbors_starts = np.zeros(top.num_cells() + 1, dtype=np.int32)
+    for iel in range(len(cell_to_nodal_neighbors)):
+        neighbors += cell_to_nodal_neighbors[iel]
         neighbors_starts[iel+1] = len(neighbors)
 
     neighbors = np.array(neighbors, dtype=np.int32)
 
     from meshmode.mesh import NodalAdjacency
-    return NodalAdjacency(neighbors_starts=neighbors_starts,
-                          neighbors=neighbors)
+    nodal_adjacency = NodalAdjacency(neighbors_starts=neighbors_starts,
+                                     neighbors=neighbors)
+
+    return (vertex_indices, nodal_adjacency)
 
 
-def get_firedrake_boundary_tags(fdrake_mesh):
+def _get_firedrake_boundary_tags(fdrake_mesh):
     """
     Return a tuple of bdy tags as requested in
     the construction of a :mod:`meshmode` :class:`Mesh`
@@ -155,4 +146,303 @@ def get_firedrake_boundary_tags(fdrake_mesh):
 
     return tuple(bdy_tags)
 
+
+def _get_firedrake_facial_adjacency_groups(fdrake_mesh):
+    # FIXME: do docs
+    top = fdrake_mesh.topology
+    # We only need one group
+    # for interconnectivity and one for boundary connectivity.
+    # The tricky part is moving from firedrake local facet numbering
+    # (ordered lexicographically by the vertex excluded from the face)
+    # and meshmode's facet ordering: obtained from a simplex element
+    # group
+    from meshmode.mesh import SimplexElementGroup
+    mm_simp_group = SimplexElementGroup(1, None, None,
+                                        dim=top.cell_dimension())
+    mm_face_vertex_indices = mm_simp_group.face_vertex_indices()
+    # map firedrake local face number to meshmode local face number
+    fd_loc_fac_nr_to_mm = {}
+    # Figure out which vertex is excluded to get the corresponding
+    # firedrake local index
+    for mm_loc_fac_nr, face in enumerate(mm_face_vertex_indices):
+        for fd_loc_fac_nr in range(top.ufl_cell().num_vertices()):
+            if fd_loc_fac_nr not in face:
+                fd_loc_fac_nr_to_mm[fd_loc_fac_nr] = mm_loc_fac_nr
+                break
+
+    # First do the interconnectivity group
+
+    # Get the firedrake cells associated to each interior facet
+    int_facet_cell = top.interior_facets.facet_cell
+    # Get the firedrake local facet numbers and map them to the
+    # meshmode local facet numbers
+    int_fac_loc_nr = top.interior_facets.local_facet_dat.data
+    int_fac_loc_nr = \
+        np.array([[fd_loc_fac_nr_to_mm[fac_nr] for fac_nr in fac_nrs]
+                  for fac_nrs in int_fac_loc_nr])
+    # elements neighbors element_faces neighbor_faces are as required
+    # for a :class:`FacialAdjacencyGroup`.
+    from meshmode.mesh import Mesh
+
+    int_elements = int_facet_cell.flatten()
+    int_neighbors = np.concatenate((int_facet_cell[:, 1], int_facet_cell[:, 0]))
+    int_element_faces = int_fac_loc_nr.flatten().astype(Mesh.face_id_dtype)
+    int_neighbor_faces = np.concatenate((int_fac_loc_nr[:, 1],
+                                        int_fac_loc_nr[:, 0]))
+    int_neighbor_faces = int_neighbor_faces.astype(Mesh.face_id_dtype)
+
+    from meshmode.mesh import FacialAdjacencyGroup
+    interconnectivity_grp = FacialAdjacencyGroup(igroup=0, ineighbor_group=0,
+                                                 elements=int_elements,
+                                                 neighbors=int_neighbors,
+                                                 element_faces=int_element_faces,
+                                                 neighbor_faces=int_neighbor_faces)
+
+    # Now look at exterior facets
+
+    # We can get the elements directly from exterior facets
+    ext_elements = top.exterior_facets.facet_cell.flatten()
+
+    ext_element_faces = top.exterior_facets.local_facet_dat.data
+    ext_element_faces = ext_element_faces.astype(Mesh.face_id_dtype)
+    ext_neighbor_faces = np.zeros(ext_element_faces.shape, dtype=np.int32)
+    ext_neighbor_faces = ext_neighbor_faces.astype(Mesh.face_id_dtype)
+
+    # Now we need to tag the boundary
+    bdy_tags = _get_firedrake_boundary_tags(top)
+    boundary_tag_to_index = {bdy_tag: i for i, bdy_tag in enumerate(bdy_tags)}
+
+    def boundary_tag_bit(boundary_tag):
+        try:
+            return 1 << boundary_tag_to_index[boundary_tag]
+        except KeyError:
+            raise 0
+
+    from meshmode.mesh import BTAG_ALL, BTAG_REALLY_ALL
+    ext_neighbors = np.zeros(ext_elements.shape, dtype=np.int32)
+    for ifac, marker in enumerate(top.exterior_facets.markers):
+        ext_neighbors[ifac] = boundary_tag_bit(BTAG_ALL) \
+                              | boundary_tag_bit(BTAG_REALLY_ALL) \
+                              | boundary_tag_bit(marker)
+
+    exterior_grp = FacialAdjacencyGroup(igroup=0, ineighbor=None,
+                                        elements=ext_elements,
+                                        element_faces=ext_element_faces,
+                                        neighbors=ext_neighbors,
+                                        neighbor_faces=ext_neighbor_faces)
+
+    return [{0: interconnectivity_grp, None: exterior_grp}]
+
 # }}}
+
+
+# {{{ Orientation computation
+
+def _get_firedrake_orientations(fdrake_mesh, unflipped_group, vertices,
+                                normals=None, no_normals_warn=True):
+    # FIXME : Fix docs
+    """
+    Return the orientations of the mesh elements:
+    an array, the *i*th element is > 0 if the *ith* element
+    is positively oriented, < 0 if negatively oriented.
+    Mesh must have co-dimension 0 or 1.
+
+    :param normals: _Only_ used if :param:`mesh` is a 1-surface
+        embedded in 2-space. In this case,
+        - If *None* then
+          all elements are assumed to be positively oriented.
+        - Else, should be a list/array whose *i*th entry
+          is the normal for the *i*th element (*i*th
+          in :param:`mesh`*.coordinate.function_space()*'s
+          :attribute:`cell_node_list`)
+
+    :param no_normals_warn: If *True*, raises a warning
+        if :param:`mesh` is a 1-surface embedded in 2-space
+        and :param:`normals` is *None*.
+    """
+    # compute orientations
+    tdim = fdrake_mesh.topological_dimension()
+    gdim = fdrake_mesh.geometric_dimension()
+
+    orient = None
+    if gdim == tdim:
+        # We use :mod:`meshmode` to check our orientations
+        from meshmode.mesh.processing import \
+            find_volume_mesh_element_group_orientation
+
+        orient = find_volume_mesh_element_group_orientation(vertices,
+                                                            unflipped_group)
+
+    if tdim == 1 and gdim == 2:
+        # In this case we have a 1-surface embedded in 2-space
+        orient = np.ones(fdrake_mesh.num_cells())
+        if normals:
+            for i, (normal, vertices) in enumerate(zip(np.array(normals),
+                                                       vertices)):
+                if np.cross(normal, vertices) < 0:
+                    orient[i] = -1.0
+        elif no_normals_warn:
+            warn("Assuming all elements are positively-oriented.")
+
+    elif tdim == 2 and gdim == 3:
+        # In this case we have a 2-surface embedded in 3-space
+        orient = fdrake_mesh.cell_orientations().dat.data
+        r"""
+            Convert (0 \implies negative, 1 \implies positive) to
+            (-1 \implies negative, 1 \implies positive)
+        """
+        orient *= 2
+        orient -= np.ones(orient.shape, dtype=orient.dtype)
+    #Make sure the mesh fell into one of the cases
+    """
+    NOTE : This should be guaranteed by previous checks,
+           but is here anyway in case of future development.
+    """
+    assert orient is not None, "something went wrong, contact the developer"
+    return orient
+
+# }}}
+
+
+# {{{ Mesh conversion
+
+def import_firedrake_mesh(fdrake_mesh):
+    # FIXME : docs
+    # Type validation
+    from firedrake.mesh import MeshGeometry
+    if not isinstance(fdrake_mesh, MeshGeometry):
+        raise TypeError(":param:`fdrake_mesh_topology` must be a "
+                        ":mod:`firedrake` :class:`MeshGeometry`, "
+                        "not %s." % type(fdrake_mesh))
+    assert fdrake_mesh.ufl_cell().is_simplex(), "Mesh must use simplex cells"
+    gdim = fdrake_mesh.geometric_dimension()
+    tdim = fdrake_mesh.topological_dimension()
+    assert gdim - tdim in [0, 1], "Mesh co-dimension must be 0 or 1"
+    fdrake_mesh.init()
+
+    # Get all the nodal information we can from the topology
+    bdy_tags = _get_firedrake_boundary_tags(fdrake_mesh)
+    vertex_indices, nodal_adjacency = _get_firedrake_nodal_info(fdrake_mesh)
+
+    # Grab the mesh reference element and cell dimension
+    coord_finat_elt = fdrake_mesh.coordinates.function_space().finat_element
+    cell_dim = fdrake_mesh.cell_dimension()
+
+    # Get finat unit nodes and map them onto the meshmode reference simplex
+    from meshmode.interop.firedrake.reference_cell import (
+        get_affine_reference_simplex_mapping, get_finat_element_unit_nodes)
+    finat_unit_nodes = get_finat_element_unit_nodes(coord_finat_elt)
+    fd_ref_to_mm = get_affine_reference_simplex_mapping(cell_dim, True)
+    finat_unit_nodes = fd_ref_to_mm(finat_unit_nodes)
+
+    # Now grab the nodes
+    coords = fdrake_mesh.coordinates
+    cell_node_list = coords.function_space().cell_node_list
+    nodes = np.real(coords.dat.data[cell_node_list])
+    # Add extra dim in 1D so that have [nelements][nunit_nodes][dim]
+    if len(nodes.shape) == 2:
+        nodes = np.reshape(nodes, nodes.shape + (1,))
+    # Now we want the nodes to actually have shape [dim][nelements][nunit_nodes]
+    nodes = np.transpose(nodes, (2, 0, 1))
+
+    # make a group (possibly with some elements that need to be flipped)
+    from meshmode.mesh import SimplexElementGroup
+    unflipped_group = SimplexElementGroup(coord_finat_elt.degree,
+                                          vertex_indices,
+                                          nodes,
+                                          dim=cell_dim,
+                                          unit_nodes=finat_unit_nodes)
+
+    # Next get the vertices (we'll need these for the orientations)
+
+    coord_finat = fdrake_mesh.coordinates.function_space().finat_element
+    # unit_vertex_indices are the element-local indices of the nodes
+    # which coincide with the vertices, i.e. for element *i*,
+    # vertex 0's coordinates would be nodes[i][unit_vertex_indices[0]].
+    # This assumes each vertex has some node which coincides with it...
+    # which is normally fine to assume for firedrake meshes.
+    unit_vertex_indices = []
+    # iterate through the dofs associated to each vertex on the
+    # reference element
+    for _, dofs in sorted(six.iteritems(coord_finat.entity_dofs()[0])):
+        assert len(dofs) == 1, \
+            "The function space of the mesh coordinates must have" \
+            " exactly one degree of freedom associated with " \
+            " each vertex in order to determine vertex coordinates"
+        dof, = dofs
+        unit_vertex_indices.append(dof)
+
+    # Now get the vertex coordinates
+    vertices = {}
+    for icell, cell_vertex_indices in enumerate(vertex_indices):
+        for local_vert_id, global_vert_id in enumerate(cell_vertex_indices):
+            if global_vert_id in vertices:
+                continue
+            local_node_nr = unit_vertex_indices[local_vert_id]
+            vertices[global_vert_id] = nodes[:, icell, local_node_nr]
+    # Stuff the vertices in a *(dim, nvertices)*-shaped numpy array
+    vertices = np.array([vertices[i] for i in range(len(vertices))]).T
+
+    # Use the vertices to compute the orientations and flip the group
+    # FIXME : Allow for passing in normals/no normals warn
+    orient = _get_firedrake_orientations(fdrake_mesh, unflipped_group, vertices)
+    from meshmode.mesh.processing import flip_simplex_element_group
+    group = flip_simplex_element_group(vertices, unflipped_group, orient < 0)
+
+    # Now, any flipped element had its 0 vertex and 1 vertex exchanged.
+    # This changes the local facet nr, so we need to create and then
+    # fix our facial adjacency groups. To do that, we need to figure
+    # out which local facet numbers switched.
+    from meshmode.mesh import SimplexElementGroup
+    mm_simp_group = SimplexElementGroup(1, None, None,
+                                        dim=fdrake_mesh.cell_dimension())
+    face_vertex_indices = mm_simp_group.face_vertex_indices()
+    # face indices of the faces not containing vertex 0 and not
+    # containing vertex 1, respectively
+    no_zero_face_ndx, no_one_face_ndx = None, None
+    for iface, face in enumerate(face_vertex_indices):
+        if 0 not in face:
+            no_zero_face_ndx = iface
+        elif 1 not in face:
+            no_one_face_ndx = iface
+
+    unflipped_facial_adjacency_groups = \
+        _get_firedrake_facial_adjacency_groups(fdrake_mesh)
+
+    def flip_local_face_indices(arr, elements):
+        arr = np.copy(arr)
+        to_no_one = np.logical_and(orient[elements] < 0, arr == no_zero_face_ndx)
+        to_no_zero = np.logical_and(orient[elements] < 0, arr == no_one_face_ndx)
+        arr[to_no_one], arr[to_no_zero] = no_one_face_ndx, no_zero_face_ndx
+        return arr
+
+    facial_adjacency_groups = []
+    from meshmode.mesh import FacialAdjacencyGroup
+    for igroup, fagrps in enumerate(unflipped_facial_adjacency_groups):
+        facial_adjacency_groups.append({})
+        for ineighbor_group, fagrp in six.iteritems(fagrps):
+            new_element_faces = flip_local_face_indices(fagrp.element_faces,
+                                                        fagrp.elements)
+            new_neighbor_faces = flip_local_face_indices(fagrp.neighbor_faces,
+                                                         fagrp.neighbors)
+            new_fagrp = FacialAdjacencyGroup(igroup=igroup,
+                                             ineighbor_group=ineighbor_group,
+                                             elements=fagrp.elements,
+                                             element_faces=new_element_faces,
+                                             neighbors=fagrp.neighbors,
+                                             neighbor_faces=new_neighbor_faces)
+        facial_adjacency_groups[igroup][ineighbor_group] = new_fagrp
+
+    from meshmode.mesh import Mesh
+    return Mesh(vertices, [group],
+                boundary_tags=bdy_tags,
+                nodal_adjacency=nodal_adjacency,
+                facial_adjacency_groups=facial_adjacency_groups)
+
+# }}}
+
+
+from firedrake import UnitSquareMesh
+m = UnitSquareMesh(10, 10)
+m.init()
+mm_mesh = import_firedrake_mesh(m)
diff --git a/meshmode/interop/firedrake/reference_cell.py b/meshmode/interop/firedrake/reference_cell.py
index 100f426f..b8509582 100644
--- a/meshmode/interop/firedrake/reference_cell.py
+++ b/meshmode/interop/firedrake/reference_cell.py
@@ -32,7 +32,7 @@ __doc__ = """
 
 # {{{ Map between reference simplices
 
-def get_affine_reference_simplex_mapping(self, spat_dim, firedrake_to_meshmode=True):
+def get_affine_reference_simplex_mapping(spat_dim, firedrake_to_meshmode=True):
     """
     Returns a function which takes a numpy array points
     on one reference cell and maps each
@@ -115,12 +115,16 @@ def get_finat_element_unit_nodes(finat_element):
     (equivalently, FInAT/FIAT's) reference coordinates
 
     :param finat_element: A :class:`finat.finiteelementbase.FiniteElementBase`
-        instance (i.e. a firedrake function space's reference element)
+        instance (i.e. a firedrake function space's reference element).
+        The refernce element of the finat element *MUST* be a simplex
     :return: A numpy array of shape *(dim, nunit_nodes)* holding the unit
              nodes used by this element. *dim* is the dimension spanned
              by the finat element's reference element
              (see its ``cell`` attribute)
     """
+    from FIAT.reference_element import Simplex
+    assert isinstance(finat_element.cell, Simplex), \
+        "Reference element of the finat element MUST be a simplex"
     # point evaluators is a list of functions *p_0,...,p_{n-1}*.
     # *p_i(f)* evaluates function *f* at node *i* (stored as a tuple),
     # so to recover node *i* we need to evaluate *p_i* at the identity
diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 97b9bbda..36e2505e 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -363,36 +363,36 @@ def test_volume_mesh_element_orientations(mesh):
 
 # {{{ flips
 
-def flip_simplex_element_group(vertices, grp, grp_flip_flags):
-    from modepy.tools import barycentric_to_unit, unit_to_barycentric
-
-    from meshmode.mesh import SimplexElementGroup
-
-    if not isinstance(grp, SimplexElementGroup):
-        raise NotImplementedError("flips only supported on "
-                "exclusively SimplexElementGroup-based meshes")
-
-    # Swap the first two vertices on elements to be flipped.
 
-    new_vertex_indices = grp.vertex_indices.copy()
-    new_vertex_indices[grp_flip_flags, 0] \
-            = grp.vertex_indices[grp_flip_flags, 1]
-    new_vertex_indices[grp_flip_flags, 1] \
-            = grp.vertex_indices[grp_flip_flags, 0]
-
-    # Generate a resampling matrix that corresponds to the
-    # first two barycentric coordinates being swapped.
+def get_simplex_element_flip_matrix(order, unit_nodes):
+    """
+    Generate a resampling matrix that corresponds to the
+    first two barycentric coordinates being swapped.
+
+    :param order: The order of the function space on the simplex,
+                 (see second argument in
+                  :fun:`modepy.simplex_best_available_basis`)
+    :param unit_nodes: A np array of unit nodes with shape
+                       *(dim, nunit_nodes)*
+
+    :return: A numpy array of shape *(dim, dim)* which, when applied
+             to the matrix of nodes (shaped *(dim, nunit_nodes)*)
+             corresponds to the first two barycentric coordinates
+             being swapped
+    """
+    from modepy.tools import barycentric_to_unit, unit_to_barycentric
 
-    bary_unit_nodes = unit_to_barycentric(grp.unit_nodes)
+    bary_unit_nodes = unit_to_barycentric(unit_nodes)
 
     flipped_bary_unit_nodes = bary_unit_nodes.copy()
     flipped_bary_unit_nodes[0, :] = bary_unit_nodes[1, :]
     flipped_bary_unit_nodes[1, :] = bary_unit_nodes[0, :]
     flipped_unit_nodes = barycentric_to_unit(flipped_bary_unit_nodes)
 
+    dim = unit_nodes.shape[0]
     flip_matrix = mp.resampling_matrix(
-            mp.simplex_best_available_basis(grp.dim, grp.order),
-            flipped_unit_nodes, grp.unit_nodes)
+            mp.simplex_best_available_basis(dim, order),
+            flipped_unit_nodes, unit_nodes)
 
     flip_matrix[np.abs(flip_matrix) < 1e-15] = 0
 
@@ -401,7 +401,26 @@ def flip_simplex_element_group(vertices, grp, grp_flip_flags):
             np.dot(flip_matrix, flip_matrix)
             - np.eye(len(flip_matrix))) < 1e-13
 
+    return flip_matrix
+
+
+def flip_simplex_element_group(vertices, grp, grp_flip_flags):
+    from meshmode.mesh import SimplexElementGroup
+
+    if not isinstance(grp, SimplexElementGroup):
+        raise NotImplementedError("flips only supported on "
+                "exclusively SimplexElementGroup-based meshes")
+
+    # Swap the first two vertices on elements to be flipped.
+
+    new_vertex_indices = grp.vertex_indices.copy()
+    new_vertex_indices[grp_flip_flags, 0] \
+            = grp.vertex_indices[grp_flip_flags, 1]
+    new_vertex_indices[grp_flip_flags, 1] \
+            = grp.vertex_indices[grp_flip_flags, 0]
+
     # Apply the flip matrix to the nodes.
+    flip_matrix = get_simplex_element_flip_matrix(grp.order, grp.unit_nodes)
     new_nodes = grp.nodes.copy()
     new_nodes[:, grp_flip_flags] = np.einsum(
             "ij,dej->dei",
-- 
GitLab