diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 27cb9a032d064d47d57202aa11a5d7f75eb3fb8f..37e4ac264c28f509623465c1e967fc909c070113 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -70,7 +70,11 @@ def partition_mesh(mesh, part_per_element, part_nr):
     new_nodes = []
 
     # The set of vertex indices we need.
-    index_set = set()
+    # NOTE: There are two methods for producing required_indices.
+    #   Optimizations may come from further exploring these options.
+    #index_set = np.array([], dtype=int)
+    index_sets = np.array([], dtype=set)
+
     skip_groups = []
     num_prev_elems = 0
     start_idx = 0
@@ -103,13 +107,15 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 new_idx = j - start_idx
                 new_nodes[group_nr][i, new_idx, :] = mesh_group.nodes[i, elems, :]
 
-        index_set |= set(new_indices[group_nr].ravel())
+        #index_set = np.append(index_set, new_indices[group_nr].ravel())
+        index_sets = np.append(index_sets, set(new_indices[group_nr].ravel()))
 
         num_prev_elems += mesh_group.nelements
         start_idx = end_idx
 
     # A sorted np.array of vertex indices we need (without duplicates).
-    required_indices = np.array(list(index_set))
+    #required_indices = np.unique(np.sort(index_set))
+    required_indices = np.array(list(set.union(*index_sets)))
 
     new_vertices = np.zeros((mesh.ambient_dim, len(required_indices)))
     for dim in range(mesh.ambient_dim):
@@ -124,8 +130,6 @@ def partition_mesh(mesh, part_per_element, part_nr):
                     new_indices[group_nr][i, j] = np.where(
                         required_indices == original_index)[0]
 
-    from meshmode.mesh import SimplexElementGroup, Mesh
-
     new_mesh_groups = []
     for group_nr in range(num_groups):
         if group_nr not in skip_groups:
@@ -134,6 +138,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 type(mesh_group)(mesh_group.order, new_indices[group_nr],
                     new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
 
+    from meshmode.mesh import Mesh
     part_mesh = Mesh(new_vertices, new_mesh_groups)
 
     return (part_mesh, queried_elems)