diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index ff74ac81abdc6e414bd5296b27e8545a047acd25..37e4ac264c28f509623465c1e967fc909c070113 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -31,6 +31,7 @@ import modepy as mp
 
 
 __doc__ = """
+.. autofunction:: partition_mesh
 .. autofunction:: find_volume_mesh_element_orientations
 .. autofunction:: perform_flips
 .. autofunction:: find_bounding_box
@@ -40,6 +41,109 @@ __doc__ = """
 """
 
 
+def partition_mesh(mesh, part_per_element, part_nr):
+    """Extract the elements of *mesh* assigned to part *part_nr* as a new mesh.
+
+    :arg mesh: A :class:`meshmode.mesh.Mesh` to be partitioned.
+    :arg part_per_element: A :class:`numpy.ndarray` containing one
+        integer per element of *mesh* indicating which part of the
+        partitioned mesh the element is to become a part of.
+    :arg part_nr: The part number of the mesh to return.
+
+    :returns: A tuple ``(part_mesh, part_to_global)``, where *part_mesh*
+        is a :class:`meshmode.mesh.Mesh` that is a partition of mesh, and
+        *part_to_global* is a :class:`numpy.ndarray` mapping element
+        numbers on *part_mesh* to ones in *mesh*.
+    :raises ValueError: if no element of *mesh* is assigned to *part_nr*.
+
+    Elements keep their relative order. Groups of *mesh* that contribute no
+    element are left out of *part_mesh*. Only the vertices referenced by the
+    selected elements are kept; they are renumbered in ascending order of
+    their vertex number in *mesh*.
+
+    .. versionadded:: 2017.1
+
+    .. warning:: Interface is not final. Connectivity between elements
+        across groups needs to be added.
+    """
+    assert len(part_per_element) == mesh.nelements, (
+        "part_per_element must have shape (mesh.nelements,)")
+
+    # Global numbers of the requested elements. np.where returns them in
+    # ascending order, which the per-group slicing below relies on.
+    queried_elems = np.where(np.array(part_per_element) == part_nr)[0]
+    if len(queried_elems) == 0:
+        # An empty part has no vertices or groups to build a mesh from;
+        # fail early with a clear message.
+        raise ValueError(
+            "no element of mesh is assigned to part %s" % (part_nr,))
+
+    # One entry per group that contributes elements. The three lists are
+    # kept in step with each other.
+    kept_groups = []
+    new_indices = []
+    new_nodes = []
+
+    # num_prev_elems counts the elements of the groups already visited;
+    # start_idx is the position in queried_elems where this group begins.
+    num_prev_elems = 0
+    start_idx = 0
+    for mesh_group in mesh.groups:
+        # queried_elems[start_idx:end_idx] are the requested elements that
+        # lie in this group: end_idx is the position of the first requested
+        # element belonging to a later group (binary search on sorted data).
+        end_idx = np.searchsorted(
+            queried_elems, num_prev_elems + mesh_group.nelements)
+
+        if start_idx != end_idx:
+            # Element numbers relative to the start of this group.
+            elems = queried_elems[start_idx:end_idx] - num_prev_elems
+
+            kept_groups.append(mesh_group)
+
+            # Indexing with an array copies, so renumbering these vertex
+            # indices later does not modify *mesh*.
+            new_indices.append(mesh_group.vertex_indices[elems])
+
+            # Shape (ambient_dim, len(elems), nunit_nodes), stored as a
+            # C-contiguous float64 copy.
+            new_nodes.append(
+                np.array(
+                    mesh_group.nodes[:, elems, :],
+                    dtype=np.float64, order="C"))
+
+        num_prev_elems += mesh_group.nelements
+        start_idx = end_idx
+
+    # Sorted array of the vertex numbers we need (without duplicates).
+    required_indices = np.unique(
+        np.concatenate([indices.ravel() for indices in new_indices]))
+
+    # Coordinates of the kept vertices, shape (ambient_dim, nvertices).
+    new_vertices = np.array(
+        mesh.vertices[:, required_indices], dtype=np.float64, order="C")
+
+    new_mesh_groups = []
+    for mesh_group, indices, nodes in zip(kept_groups, new_indices, new_nodes):
+        # Vertex indices must refer to new_vertices, i.e. lie in the range
+        # [0, len(required_indices)). Every value of indices occurs in the
+        # sorted required_indices, so a binary search finds its exact
+        # position; the original integer type is preserved.
+        local_indices = np.searchsorted(
+            required_indices, indices).astype(indices.dtype)
+
+        # Same group class, order and unit nodes as the original group.
+        new_mesh_groups.append(
+            type(mesh_group)(
+                mesh_group.order, local_indices, nodes,
+                unit_nodes=mesh_group.unit_nodes))
+
+    from meshmode.mesh import Mesh
+    part_mesh = Mesh(new_vertices, new_mesh_groups)
+
+    return (part_mesh, queried_elems)
+
+
 # {{{ orientations
 
 def find_volume_mesh_element_group_orientation(vertices, grp):
diff --git a/requirements.txt b/requirements.txt
index 6ed102e6d1309b0e65dcfec56e15aac08448a7d4..f5bb34647a3c1fe4e8989462447a78416879ba4d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,6 @@ git+https://github.com/inducer/loopy.git
 git+https://github.com/inducer/boxtree.git
 git+https://github.com/inducer/sumpy.git
 git+https://github.com/inducer/pytential.git
+
+# pymetis is required by the tests for partition_mesh
+git+https://github.com/inducer/pymetis.git
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 3dabe0cfacde742dd4085293a9206c5a97ae6b7c..a793743ce6ff50cc67aea4b31225a1b9e4df4e3d 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -48,6 +48,55 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+# {{{ partition_mesh
+
+def test_partition_torus_mesh():
+    """Split a small torus mesh into three parts and check their sizes."""
+    from meshmode.mesh.generation import generate_torus
+    from meshmode.mesh.processing import partition_mesh
+
+    torus = generate_torus(2, 1, n_outer=2, n_inner=2)
+
+    # Parts 0, 1 and 2 are assigned 2, 4 and 2 of the 8 elements.
+    part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0])
+    expected_sizes = (2, 4, 2)
+
+    for part_nr, expected in enumerate(expected_sizes):
+        part_mesh, _ = partition_mesh(torus, part_per_element, part_nr)
+        assert part_mesh.nelements == expected
+
+
+def test_partition_boxes_mesh():
+    """Partition two merged box meshes with METIS; no element may be lost."""
+    from meshmode.mesh.generation import generate_regular_rect_mesh
+    from meshmode.mesh.processing import merge_disjoint_meshes, partition_mesh
+    from pymetis import part_graph
+
+    num_parts, npoints = 7, 5
+    mesh = merge_disjoint_meshes([
+        generate_regular_rect_mesh(a=(lo,)*3, b=(lo + 1,)*3, n=(npoints,)*3)
+        for lo in (0, 2)])
+
+    # neighbors[starts[e]:starts[e + 1]] lists the elements adjacent to e.
+    adjacency = mesh.nodal_adjacency
+    starts = adjacency.neighbors_starts
+    adjacency_list = np.zeros((mesh.nelements,), dtype=set)
+    for elem in range(mesh.nelements):
+        adjacency_list[elem] = set(
+            adjacency.neighbors[starts[elem]:starts[elem + 1]])
+
+    _, membership = part_graph(num_parts, adjacency=adjacency_list)
+    part_per_element = np.array(membership)
+
+    part_sizes = [
+        partition_mesh(mesh, part_per_element, part_nr)[0].nelements
+        for part_nr in range(num_parts)]
+
+    assert mesh.nelements == np.sum(part_sizes)
+
+# }}}
+
+
 # {{{ circle mesh
 
 def test_circle_mesh(do_plot=False):