diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 81d8526b977d7883ab75ba89da71a64b9558e9f8..e09d2d54ad8bf3efb22d64ab9490d7d08dcca6b3 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -37,6 +37,7 @@ __doc__ = """
 .. autofunction:: perform_flips
 .. autofunction:: find_bounding_box
 .. autofunction:: merge_disjoint_meshes
+.. autofunction:: split_mesh_groups
 .. autofunction:: map_mesh
 .. autofunction:: affine_map
 """
@@ -572,45 +573,60 @@ def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False):
 # {{{ split meshes
 
 
-def split_mesh_groups(mesh, element_flags):
-    """Split all the groups in *mesh* in two according to the values of
-    *element_flags*.
+def split_mesh_groups(mesh, element_flags, return_subgroup_mapping=False):
+    """Split all the groups in *mesh* in according to the values of
+    *element_flags*. The element flags are expected to be integers
+    defining, for each group, how the elements are to be split into
+    subgroups. For example, a single-group mesh with flags
+
+        .. code::
+
+            element_flags = [0, 0, 0, 42, 42, 42, 0, 0, 0, 41, 41, 41]
+
+    will create three subgroups. The integer flags need not be increasing
+    or contiguous and can repeat across different groups (i.e. they are
+    group-local).
 
     :arg element_flags: a :class:`numpy.ndarray` with
         :attr:`~meshmode.mesh.Mesh.nelements` entries
-        indicating by their *Boolean* value how the elements in a group
-        are to be split.
+        indicating how the elements in a group are to be split.
 
     :returns: a :class:`~meshmode.mesh.Mesh` where each group has been split
-        according to flags in *element_flags*.
+        according to flags in *element_flags*. If *return_subgroup_mapping*
+        is *True*, it also returns a mapping of
+        ``(group_index, subgroup) -> new_group_index``
+
     """
     assert element_flags.shape == (mesh.nelements,)
 
     new_groups = []
-    for grp in mesh.groups:
-        mask = element_flags[
+    subgroup_to_group_map = {}
+
+    for igrp, grp in enumerate(mesh.groups):
+        grp_flags = element_flags[
                 grp.element_nr_base:grp.element_nr_base + grp.nelements]
+        unique_grp_flags = np.unique(grp_flags)
 
-        vertex_indices = grp.vertex_indices[mask, :].copy()
-        if vertex_indices.size > 0:
-            new_groups.append(grp.copy(
-                vertex_indices=vertex_indices,
-                nodes=grp.nodes[:, mask, :].copy()
-                ))
+        for flag in unique_grp_flags:
+            subgroup_to_group_map[igrp, flag] = len(new_groups)
 
-        vertex_indices = grp.vertex_indices[~mask, :].copy()
-        if vertex_indices.size > 0:
+            mask = grp_flags == flag
             new_groups.append(grp.copy(
-                vertex_indices=vertex_indices,
-                nodes=grp.nodes[:, ~mask, :].copy()
+                vertex_indices=grp.vertex_indices[mask, :].copy(),
+                nodes=grp.nodes[:, mask, :].copy()
                 ))
 
     from meshmode.mesh import Mesh
-    return Mesh(
+    mesh = Mesh(
             vertices=mesh.vertices,
             groups=new_groups,
             is_conforming=mesh.is_conforming)
 
+    if return_subgroup_mapping:
+        return mesh, subgroup_to_group_map
+    else:
+        return mesh
+
 # }}}
 
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index b6a564abb7647cc20db52054b0a567f19a341f0a..eb69cf7b5d85e004f534b45f6589bb7a5b19a6a9 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -1309,7 +1309,7 @@ def test_mesh_multiple_groups(ctx_factory, ambient_dim, visualize=False):
     from meshmode.mesh.processing import split_mesh_groups
     element_flags = np.any(
             mesh.vertices[0, mesh.groups[0].vertex_indices] < 0.0,
-            axis=1)
+            axis=1).astype(np.int)
     mesh = split_mesh_groups(mesh, element_flags)
 
     assert len(mesh.groups) == 2
@@ -1318,10 +1318,14 @@ def test_mesh_multiple_groups(ctx_factory, ambient_dim, visualize=False):
 
     if visualize and ambient_dim == 2:
         from meshmode.mesh.visualization import draw_2d_mesh
-        draw_2d_mesh(mesh, draw_element_numbers=True, draw_face_numbers=True,
+        draw_2d_mesh(mesh,
+                draw_vertex_numbers=False,
+                draw_element_numbers=True,
+                draw_face_numbers=False,
                 set_bounding_box=True)
+
         import matplotlib.pyplot as plt
-        plt.show()
+        plt.savefig("test_mesh_multiple_groups_2d_elements.png", dpi=300)
 
     from meshmode.discretization import Discretization
     from meshmode.discretization.poly_element import \