diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 93943e1d70d4982e778ac921f31fa131dc8004a0..762cf41c8a77cec95af15b430e544249793c16c8 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -138,8 +138,12 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 type(mesh_group)(mesh_group.order, new_indices[group_nr],
                     new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
 
+    num_parts = np.max(part_per_element)
+    boundary_tags = list(range(num_parts))
+
     from meshmode.mesh import Mesh
-    part_mesh = Mesh(new_vertices, new_mesh_groups, facial_adjacency_groups=None)
+    part_mesh = Mesh(new_vertices, new_mesh_groups, \
+        facial_adjacency_groups=None, boundary_tags=boundary_tags)
 
     from meshmode.mesh import BTAG_ALL
 
@@ -162,7 +166,12 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 for idx in np.where(parent_facial_group.elements == parent_elem)[0]:
                     if parent_facial_group.neighbors[idx] >= 0:
                         if face == parent_facial_group.element_faces[idx]:
-                            f_group.neighbors[elem_idx] = -(tag & ~part_mesh.boundary_tag_bit(BTAG_ALL))
+                            rank_neighbor = parent_facial_group.neighbors[idx]
+                            # TODO: With mulitple groups, rank_neighbors will be wrong.
+                            neighbor_part_num = part_per_element[rank_neighbor]
+                            tag = tag & ~part_mesh.boundary_tag_bit(BTAG_ALL)
+                            tag = tag | part_mesh.boundary_tag_bit(neighbor_part_num)
+                            f_group.neighbors[elem_idx] = -tag
                             #print("Boundary face", face, "of element", elem, "should be connected to element", parent_elem, "in parent group", parent_group)
 
     return (part_mesh, queried_elems)