diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index ff74ac81abdc6e414bd5296b27e8545a047acd25..27d3ef6e39b9e8fc912ca0109332706bb215b66e 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -39,6 +39,57 @@ __doc__ = """
 .. autofunction:: affine_map
 """
 
+def partition_mesh(mesh, part_per_element, part_nr):
+    """
+    :arg mesh: A :class:`meshmode.mesh.Mesh` to be partitioned.
+    :arg part_per_element: A :class:`numpy.ndarray` containing one
+        integer per element of *mesh* indicating which part of the
+        partitioned mesh the element is to become a part of.
+    :arg part_nr: The part number of the mesh to return.
+
+    :returns: A tuple ``(part_mesh, part_to_global)``, where *part_mesh*
+        is a :class:`meshmode.mesh.Mesh` that is a partition of mesh, and
+        *part_to_global* is a :class:`numpy.ndarray` mapping element
+        numbers on *part_mesh* to ones in *mesh*.
+    """
+
+    queried_elems = np.where(part_per_element == part_nr)[0]
+
+    '''
+    Here I assume that all elements are taken from the 0th group.
+    Will there be more groups? How do I handle this?
+    My guess is that len(part_per_element) will be equal to the sum of the
+    number of elements in each group. And that the first element of groups[1]
+    comes just after the last element in groups[0].
+    '''
+    mesh_group = mesh.groups[0]
+
+    new_vertex_indices = mesh_group.vertex_indices[queried_elems]
+
+    # A sorted np.array of vertex indices we need in our new mesh (without duplicates).
+    required_vertex_indices = np.array(list(set(new_vertex_indices.ravel())))
+
+    new_vertices = np.zeros((mesh.ambient_dim, len(required_vertex_indices)))
+    for dim in range(mesh.ambient_dim):
+        new_vertices[dim] = mesh.vertices[dim][required_vertex_indices]
+
+    # We need to update our indices to be in range [0, len(required_vertex_indices)].
+    for i in range(len(new_vertex_indices)):
+        for j in range(len(new_vertex_indices[0])):
+            new_vertex_indices[i, j] = np.where(required_vertex_indices == new_vertex_indices[i, j])[0]
+
+    new_nodes = np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes))
+    for i in range(mesh.ambient_dim):
+        for j in range(len(queried_elems)):
+            new_nodes[i, j, :] = mesh_group.nodes[i, queried_elems[j], :]
+
+    from meshmode.mesh import MeshElementGroup, Mesh
+
+    mesh_element_group = MeshElementGroup(mesh_group.order, new_vertex_indices, new_nodes, element_nr_base=mesh_group.element_nr_base, node_nr_base=mesh_group.node_nr_base, unit_nodes=mesh_group.unit_nodes, dim=mesh_group.dim)
+
+    part_mesh = Mesh(new_vertices, [mesh_element_group])
+
+    return (part_mesh, queried_elems)
 
 # {{{ orientations