diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 8a6f0464299e5cc0c3532abd6af2ab7f3908dfec..27cb9a032d064d47d57202aa11a5d7f75eb3fb8f 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -53,6 +53,11 @@ def partition_mesh(mesh, part_per_element, part_nr):
         is a :class:`meshmode.mesh.Mesh` that is a partition of mesh, and
         *part_to_global* is a :class:`numpy.ndarray` mapping element
         numbers on *part_mesh* to ones in *mesh*.
+
+    .. versionadded:: 2017.1
+
+    .. warning:: Interface is not final. Connectivity between elements
+        across groups needs to be added.
     """
     assert len(part_per_element) == mesh.nelements, (
         "part_per_element must have shape (mesh.nelements,)")
@@ -126,7 +131,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
         if group_nr not in skip_groups:
             mesh_group = mesh.groups[group_nr]
             new_mesh_groups.append(
-                SimplexElementGroup(mesh_group.order, new_indices[group_nr],
+                type(mesh_group)(mesh_group.order, new_indices[group_nr],
                     new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
 
     part_mesh = Mesh(new_vertices, new_mesh_groups)