diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 7f520f44bf583c7a57ef3b8b5383884c144c4e63..05ce910e58ba4eac79c57a3194091bec1e5ea962 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -56,6 +56,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
     """
     assert len(part_per_element) == mesh.nelements
 
+    # Contains the indices of the elements requested.
     queried_elems = np.where(np.array(part_per_element) == part_nr)[0]
 
     num_groups = len(mesh.groups)
@@ -68,39 +69,61 @@ def partition_mesh(mesh, part_per_element, part_nr):
     start_idx = 0
     for group_nr in range(num_groups):
         mesh_group = mesh.groups[group_nr]
-        #end_idx = start_idx + np.where(queried_elems >= mesh_group.nelements - start_idx)[0] - 1
+
+        # Find the index of first element in the next group
         end_idx = len(queried_elems)
+        for idx in range(len(queried_elems)):
+            if queried_elems[idx] - start_idx > mesh_group.nelements:
+                end_idx = idx
+                break
 
-        new_indices.append(mesh_group.vertex_indices[queried_elems[start_idx:end_idx]])
+        elems = queried_elems[start_idx:end_idx]
+        new_indices.append(mesh_group.vertex_indices[elems])
 
-        new_nodes.append(np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes)))
+        new_nodes.append(
+            np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes)))
         for i in range(mesh.ambient_dim):
             for j in range(start_idx, end_idx):
-                new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, queried_elems[j], :]
+                elems = queried_elems[j]
+                new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, elems, :]
 
         index_set |= set(new_indices[group_nr].ravel())
 
         start_idx = end_idx
 
-    # A sorted np.array of vertex indices we need in our new mesh (without duplicates).
+    # A sorted np.array of vertex indices we need (without duplicates).
     required_indices = np.array(list(index_set))
 
     new_vertices = np.zeros((mesh.ambient_dim, len(required_indices)))
     for dim in range(mesh.ambient_dim):
         new_vertices[dim] = mesh.vertices[dim][required_indices]
 
-    # We need to update our indices to be in range [0, len(mesh_group.nelements)].
+    # Our indices need to be in range [0, len(mesh_group.nelements)].
     for group_nr in range(num_groups):
         for i in range(len(new_indices[group_nr])):
             for j in range(len(new_indices[group_nr][0])):
-                new_indices[group_nr][i, j] = np.where(required_indices == new_indices[group_nr][i, j])[0]
+                original_index = new_indices[group_nr][i, j]
+                new_indices[group_nr][i, j] = np.where(
+                    required_indices == original_index)[0]
+
+    """
+    print("mesh vertices: ", mesh.vertices)
+    print("mesh indices: ", mesh.groups[0].vertex_indices)
+    print("mesh nodes: ", mesh.groups[0].nodes)
+    print("queried_elems: ", queried_elems)
+    print("indices: ", new_indices[0])
+    print("nodes: ", new_nodes[0])
+    print("vertices: ", new_vertices)
+    """
 
     from meshmode.mesh import MeshElementGroup, Mesh
 
     new_mesh_groups = []
     for group_nr in range(num_groups):
         mesh_group = mesh.groups[group_nr]
-        new_mesh_groups.append(MeshElementGroup(mesh_group.order, new_indices[group_nr], new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes, dim=mesh_group.dim))
+        new_mesh_groups.append(
+            MeshElementGroup(mesh_group.order, new_indices[group_nr],
+                new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
 
     part_mesh = Mesh(new_vertices, new_mesh_groups)
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index b5b2cde37195eee8e8398ea4de808a76224f22ae..99a9284e291ffd661cb0b0b878e6b08ade531945 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -47,38 +47,44 @@ import pytest
 import logging
 logger = logging.getLogger(__name__)
 
+
 # {{{ partition_mesh
 
-@pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"])
-@pytest.mark.parametrize("npoints", [10, 1000])
-def test_partition_mesh(mesh_type=None, npoints=None):
+#@pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"])
+#@pytest.mark.parametrize("npoints", [10, 1000])
+def test_partition_mesh(mesh_type, npoints):
     from meshmode.mesh.generation import generate_torus
-
     my_mesh = generate_torus(2, 1, n_outer=2, n_inner=2)
 
     part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0])
 
-    print(my_mesh.nelements, my_mesh.groups[0].nelements)
-
-    #(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1)
     from meshmode.mesh.processing import partition_mesh
-    partition_mesh(my_mesh, part_per_element, 1)
-    """from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish
+    (part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1)
+
+
+    """
+    from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish
 
     if mesh_type == "cloverleaf":
         mesh = make_curve_mesh(cloverleaf, np.linspace(0, 1, npoints), order=3)
     elif mesh_type == "starfish":
         mesh = make_curve_mesh(starfish, np.linspace(0, 1, npoints), order=3)
 
+    #from meshmode.mesh.visualization import draw_2d_mesh
+    #draw_2d_mesh(mesh)
+    #import matplotlib.pyplot as pt
+    #pt.show()
+
     #TODO: Create an actuall adjacency list from the mesh.
-    adjacency_list = mesh.nodal_adjacency
+    adjacency_list = mesh.facial_adjacency_groups
+    print(adjacency_list)
 
     from pymetis import part_graph
     part_per_element = np.array(part_graph(3, adjacency=adjacency_list))
 
     from meshmode.mesh.processing import partition_mesh
     (part_mesh, part_to_global) = partition_mesh(mesh, part_per_element, 0)
-"""
+    """
 # }}}