diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index f49ccfbeb9bde6b8a8082bdeb5ed1af88c7736f5..49d27253d89e28aefd34f0b983c313cbb867577e 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -141,17 +141,30 @@ def partition_mesh(mesh, part_per_element, part_nr):
     from meshmode.mesh import Mesh
     part_mesh = Mesh(new_vertices, new_mesh_groups, facial_adjacency_groups=None)
 
-    return (part_mesh, queried_elems)
-
+    from meshmode.mesh import BTAG_ALL
+
+    for igrp in range(num_groups):
+        f_group = part_mesh.facial_adjacency_groups[igrp][None]
+        grp_elems = f_group.elements
+        grp_faces = f_group.element_faces
+        for elem_idx in range(len(grp_elems)):
+            elem = grp_elems[elem_idx]
+            face = grp_faces[elem_idx]
+            tag = -f_group.neighbors[elem_idx]
+            parent_elem = queried_elems[elem]
+            parent_group = 0
+            while parent_elem >= mesh.groups[parent_group].nelements:
+                parent_elem -= mesh.groups[parent_group].nelements
+                parent_group += 1
+            assert parent_group < num_groups, "oops..."
+            parent_facial_group = mesh.facial_adjacency_groups[parent_group][None]
+            idxs = np.where(parent_facial_group.elements == parent_elem)[0]
+            for parent_face in parent_facial_group.element_faces[idxs]:
+                if face == parent_face:
+                    f_group.neighbors[elem_idx] = -(tag ^ part_mesh.boundary_tag_bit(BTAG_ALL))
+                    #print("Boundary face", face, "of element", elem, "should be connected to", parent_elem, "in parent mesh.")
 
-def set_rank_boundaries(part_mesh, mesh, part_to_global):
-    """
-    Looks through facial_adjacency_groups in part_mesh.
-    If a boundary is found, then it is possible that it
-    used to be connected to other faces from mesh.
-    If this is the case, then part_mesh will have special
-    boundary_tags where faces used to be connected.
-    """
+    return (part_mesh, queried_elems)
 
 
 # {{{ orientations
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index a793743ce6ff50cc67aea4b31225a1b9e4df4e3d..ca413910d6bddbded9d8004a4634c9678caf3487 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -70,11 +70,11 @@ def test_partition_boxes_mesh():
     n = 5
     num_parts = 7
     from meshmode.mesh.generation import generate_regular_rect_mesh
-    mesh1 = generate_regular_rect_mesh(a=(0, 0, 0), b=(1, 1, 1), n=(n, n, n))
-    mesh2 = generate_regular_rect_mesh(a=(2, 2, 2), b=(3, 3, 3), n=(n, n, n))
+    mesh = generate_regular_rect_mesh(a=(0, 0, 0), b=(1, 1, 1), n=(n, n, n))
+    #mesh2 = generate_regular_rect_mesh(a=(2, 2, 2), b=(3, 3, 3), n=(n, n, n))
 
-    from meshmode.mesh.processing import merge_disjoint_meshes
-    mesh = merge_disjoint_meshes([mesh1, mesh2])
+    #from meshmode.mesh.processing import merge_disjoint_meshes
+    #mesh = merge_disjoint_meshes([mesh1, mesh2])
 
     adjacency_list = np.zeros((mesh.nelements,), dtype=set)
     for elem in range(mesh.nelements):
@@ -92,7 +92,24 @@ def test_partition_boxes_mesh():
         partition_mesh(mesh, part_per_element, i)[0] for i in range(num_parts)]
 
     assert mesh.nelements == np.sum(
-        [new_meshes[i].nelements for i in range(num_parts)])
+        [new_meshes[i].nelements for i in range(num_parts)]), \
+        "part_mesh has the wrong number of elements"
+
+    print(count_BTAG_ALL(mesh))
+    print(np.sum([count_BTAG_ALL(new_meshes[i]) for i in range(num_parts)]))
+    assert count_BTAG_ALL(mesh) == np.sum(
+        [count_BTAG_ALL(new_meshes[i]) for i in range(num_parts)]), \
+        "part_mesh has the wrong number of BTAG_ALL boundaries"
+
+
+def count_BTAG_ALL(mesh):
+    num_bnds = 0
+    for adj_groups in mesh.facial_adjacency_groups:
+        bdry_group = adj_groups[None]
+        for mesh_tag in -bdry_group.neighbors:
+            if mesh_tag & mesh.boundary_tag_bit(BTAG_ALL) != 0:
+                num_bnds += 1
+    return num_bnds
 
 # }}}