diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 49d27253d89e28aefd34f0b983c313cbb867577e..93943e1d70d4982e778ac921f31fa131dc8004a0 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -157,12 +157,13 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 parent_elem -= mesh.groups[parent_group].nelements
                 parent_group += 1
             assert parent_group < num_groups, "oops..."
-            parent_facial_group = mesh.facial_adjacency_groups[parent_group][None]
-            idxs = np.where(parent_facial_group.elements == parent_elem)[0]
-            for parent_face in parent_facial_group.element_faces[idxs]:
-                if face == parent_face:
-                    f_group.neighbors[elem_idx] = -(tag ^ part_mesh.boundary_tag_bit(BTAG_ALL))
-                    #print("Boundary face", face, "of element", elem, "should be connected to", parent_elem, "in parent mesh.")
+            parent_f_group = mesh.facial_adjacency_groups[parent_group]
+            for _, parent_facial_group in parent_f_group.items():
+                for idx in np.where(parent_facial_group.elements == parent_elem)[0]:
+                    if parent_facial_group.neighbors[idx] >= 0:
+                        if face == parent_facial_group.element_faces[idx]:
+                            f_group.neighbors[elem_idx] = -(tag & ~part_mesh.boundary_tag_bit(BTAG_ALL))
+                            #print("Boundary face", face, "of element", elem, "should be connected to element", parent_elem, "in parent group", parent_group)
 
     return (part_mesh, queried_elems)
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index ca413910d6bddbded9d8004a4634c9678caf3487..97fc59de04843043556bb93fb51ee1fd8594622c 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -95,20 +95,19 @@ def test_partition_boxes_mesh():
         [new_meshes[i].nelements for i in range(num_parts)]), \
         "part_mesh has the wrong number of elements"
 
-    print(count_BTAG_ALL(mesh))
-    print(np.sum([count_BTAG_ALL(new_meshes[i]) for i in range(num_parts)]))
-    assert count_BTAG_ALL(mesh) == np.sum(
-        [count_BTAG_ALL(new_meshes[i]) for i in range(num_parts)]), \
+    assert count_btag_all(mesh) == np.sum(
+        [count_btag_all(new_meshes[i]) for i in range(num_parts)]), \
         "part_mesh has the wrong number of BTAG_ALL boundaries"
 
 
-def count_BTAG_ALL(mesh):
+def count_btag_all(mesh):
     num_bnds = 0
-    for adj_groups in mesh.facial_adjacency_groups:
-        bdry_group = adj_groups[None]
-        for mesh_tag in -bdry_group.neighbors:
-            if mesh_tag & mesh.boundary_tag_bit(BTAG_ALL) != 0:
-                num_bnds += 1
+    for adj_dict in mesh.facial_adjacency_groups:
+        for _, bdry_group in adj_dict.items():
+            for neighbors in bdry_group.neighbors:
+                if neighbors < 0:
+                    if -neighbors & mesh.boundary_tag_bit(BTAG_ALL) != 0:
+                        num_bnds += 1
     return num_bnds
 
 # }}}