diff --git a/meshmode/discretization/connection/__init__.py b/meshmode/discretization/connection/__init__.py
index bdb0f5250256ad6e8a67a10d9b735c3ff085da4b..ff19c3506bdb193f872df1a71e3c502ab157b725 100644
--- a/meshmode/discretization/connection/__init__.py
+++ b/meshmode/discretization/connection/__init__.py
@@ -35,7 +35,7 @@ from meshmode.discretization.connection.face import (
         FRESTR_INTERIOR_FACES, FRESTR_ALL_FACES,
         make_face_restriction, make_face_to_all_faces_embedding)
 from meshmode.discretization.connection.opposite_face import \
-        make_opposite_face_connection
+        make_opposite_face_connection, make_partition_connection
 from meshmode.discretization.connection.refinement import \
         make_refinement_connection
 
@@ -51,6 +51,7 @@ __all__ = [
         "make_face_restriction",
         "make_face_to_all_faces_embedding",
         "make_opposite_face_connection",
+        "make_partition_connection",
         "make_refinement_connection"
         ]
 
@@ -66,6 +67,7 @@ __doc__ = """
 .. autofunction:: make_face_to_all_faces_embedding
 
 .. autofunction:: make_opposite_face_connection
+.. autofunction:: make_partition_connection
 
 .. autofunction:: make_refinement_connection
 
diff --git a/meshmode/discretization/connection/opposite_face.py b/meshmode/discretization/connection/opposite_face.py
index 6ce70b2a11b148ce6a1e2cca6f7588ac2549aada..bfed445c5a5efac8167b5b162e9bc9269576fbab 100644
--- a/meshmode/discretization/connection/opposite_face.py
+++ b/meshmode/discretization/connection/opposite_face.py
@@ -392,4 +392,226 @@ def make_opposite_face_connection(volume_to_bdry_conn):
 
 # }}}
 
+
+def _make_cross_partition_batch(queue, vol_to_bdry_conns, adj,
+                            i_tgt_part, i_tgt_grp, i_tgt_elem, i_tgt_face):
+    """
+    Creates a batch that transfers data to a face from a face of another partition.
+
+    :arg queue:
+    :arg vol_to_bdry_conns: A list of :class:`Direct` for each partition.
+    :arg adj: :class:`InterPartitionAdj` of partition `i_tgt_part`.
+    :arg i_tgt_part: The target partition number.
+    :arg i_tgt_grp:
+    :arg i_tgt_elem:
+    :arg i_tgt_face:
+
+    :returns: ???
+    """
+
+    (i_src_part, i_src_grp, i_src_elem, i_src_face) =\
+                        adj.get_neighbor(i_tgt_elem, i_tgt_face)
+
+    src_bdry_discr = vol_to_bdry_conns[i_src_part].to_discr
+    tgt_bdry_discr = vol_to_bdry_conns[i_tgt_part].to_discr
+
+    tgt_bdry_nodes = (
+            # FIXME: This should view-then-transfer (but PyOpenCL doesn't do
+            # non-contiguous transfers for now).
+            tgt_bdry_discr.groups[i_tgt_grp].view(
+                tgt_bdry_discr.nodes().get(queue=queue))
+            [:, i_tgt_elem])
+
+    ambient_dim, nelements, n_tgt_unit_nodes = tgt_bdry_nodes.shape
+
+    # (ambient_dim, nelements, nfrom_unit_nodes)
+    src_bdry_nodes = (
+            # FIXME: This should view-then-transfer (but PyOpenCL doesn't do
+            # non-contiguous transfers for now).
+            src_bdry_discr.groups[i_src_grp].view(
+                src_bdry_discr.nodes().get(queue=queue))
+            [:, i_src_elem])
+
+    tol = 1e4 * np.finfo(tgt_bdry_nodes.dtype).eps
+
+    src_mesh_grp = src_bdry_discr.mesh.groups[i_src_grp]
+    src_grp = src_bdry_discr.groups[i_src_grp]
+
+    dim = src_grp.dim
+
+    initial_guess = np.mean(src_mesh_grp.vertex_unit_coordinates(), axis=0)
+
+    src_unit_nodes = np.empty((dim, nelements, n_tgt_unit_nodes))
+    src_unit_nodes[:] = initial_guess.reshape(-1, 1, 1)
+
+    import modepy as mp
+    src_vdm = mp.vandermonde(src_grp.basis(), src_grp.unit_nodes)
+    src_inv_t_vdm = la.inv(src_vdm.T)
+    src_nfuncs = len(src_grp.basis())
+
+    def apply_map(unit_nodes):
+        # unit_nodes: (dim, nelements, nto_unit_nodes)
+        # basis_at_unit_nodes
+        basis_at_unit_nodes = np.empty((src_nfuncs, nelements, n_tgt_unit_nodes))
+        for i, f in enumerate(src_grp.basis()):
+            basis_at_unit_nodes[i] = (
+                    f(unit_nodes.reshape(dim, -1))
+                    .reshape(nelements, n_tgt_unit_nodes))
+        intp_coeffs = np.einsum("fj,jet->fet", src_inv_t_vdm, basis_at_unit_nodes)
+        # If we're interpolating 1, we had better get 1 back.
+        one_deviation = np.abs(np.sum(intp_coeffs, axis=0) - 1)
+        assert (one_deviation < tol).all(), np.max(one_deviation)
+        return np.einsum("fet,aef->aet", intp_coeffs, src_bdry_nodes)
+
+    def get_map_jacobian(unit_nodes):
+        # unit_nodes: (dim, nelements, nto_unit_nodes)
+        # basis_at_unit_nodes
+        dbasis_at_unit_nodes = np.empty(
+                (dim, src_nfuncs, nelements, n_tgt_unit_nodes))
+        for i, df in enumerate(src_grp.grad_basis()):
+            df_result = df(unit_nodes.reshape(dim, -1))
+            for rst_axis, df_r in enumerate(df_result):
+                dbasis_at_unit_nodes[rst_axis, i] = (
+                        df_r.reshape(nelements, n_tgt_unit_nodes))
+        dintp_coeffs = np.einsum(
+                "fj,rjet->rfet", src_inv_t_vdm, dbasis_at_unit_nodes)
+        return np.einsum("rfet,aef->raet", dintp_coeffs, src_bdry_nodes)
+
+    logger.info("make_partition_connection: begin gauss-newton")
+    niter = 0
+    while True:
+        resid = apply_map(src_unit_nodes) - tgt_bdry_nodes
+        df = get_map_jacobian(src_unit_nodes)
+        df_inv_resid = np.empty_like(src_unit_nodes)
+        # For the 1D/2D accelerated versions, we'll use the normal
+        # equations and Cramer's rule. If you're looking for high-end
+        # numerics, look no further than meshmode.
+        if dim == 1:
+            # A is df.T
+            ata = np.einsum("iket,jket->ijet", df, df)
+            atb = np.einsum("iket,ket->iet", df, resid)
+            df_inv_resid = atb / ata[0, 0]
+        elif dim == 2:
+            # A is df.T
+            ata = np.einsum("iket,jket->ijet", df, df)
+            atb = np.einsum("iket,ket->iet", df, resid)
+            det = ata[0, 0]*ata[1, 1] - ata[0, 1]*ata[1, 0]
+            df_inv_resid = np.empty_like(src_unit_nodes)
+            df_inv_resid[0] = 1/det * (ata[1, 1] * atb[0] - ata[1, 0]*atb[1])
+            df_inv_resid[1] = 1/det * (-ata[0, 1] * atb[0] + ata[0, 0]*atb[1])
+        else:
+            # The boundary of a 3D mesh is 2D, so that's the
+            # highest-dimensional case we genuinely care about.
+            #
+            # This stinks, performance-wise, because it's not vectorized.
+            # But we'll only hit it for boundaries of 4+D meshes, in which
+            # case... good luck. :)
+            for e in range(nelements):
+                for t in range(n_tgt_unit_nodes):
+                    df_inv_resid[:, e, t], _, _, _ = \
+                            la.lstsq(df[:, :, e, t].T, resid[:, e, t])
+        src_unit_nodes = src_unit_nodes - df_inv_resid
+        max_resid = np.max(np.abs(resid))
+        logger.debug("gauss-newton residual: %g" % max_resid)
+        if max_resid < tol:
+            logger.info("make_partition_connection: gauss-newton: done, "
+                    "final residual: %g" % max_resid)
+            break
+        niter += 1
+        if niter > 10:
+            raise RuntimeError("Gauss-Newton (for finding partition_connection "
+                    "reference coordinates) did not converge")
+
+    def to_dev(ary):
+        return cl.array.to_device(queue, ary, array_queue=None)
+
+    done_elements = np.zeros(nelements, dtype=np.bool)
+
+    # TODO: Still need to figure out what's happening here.
+    while True:
+        todo_elements, = np.where(~done_elements)
+        if not len(todo_elements):
+            return
+        template_unit_nodes = src_unit_nodes[:, todo_elements[0], :]
+        unit_node_dist = np.max(np.max(np.abs(
+                src_unit_nodes[:, todo_elements, :]
+                -
+                template_unit_nodes.reshape(dim, 1, -1)),
+                axis=2), axis=0)
+        close_els = todo_elements[unit_node_dist < tol]
+        done_elements[close_els] = True
+        unit_node_dist = np.max(np.max(np.abs(
+                src_unit_nodes[:, todo_elements, :]
+                -
+                template_unit_nodes.reshape(dim, 1, -1)),
+                axis=2), axis=0)
+
+        from meshmode.discretization.connection import InterpolationBatch
+        yield InterpolationBatch(
+                from_group_index=i_src_grp,
+                from_element_indices=to_dev(from_bdry_element_indices[close_els]),
+                to_element_indices=to_dev(to_bdry_element_indices[close_els]),
+                result_unit_nodes=template_unit_nodes,
+                to_element_face=None)
+
+
+def make_partition_connection(vol_to_bdry_conns):
+    """
+    Given a list of boundary restriction connections *volume_to_bdry_conn*,
+    return a :class:`DirectDiscretizationConnection` that performs data
+    exchange across adjacent faces of different partitions.
+
+    :arg vol_to_bdry_conns: A list of *volume_to_bdry_conn* corresponding to
+                                a partition of a parent mesh.
+
+    :returns: A list of :class:`DirectDiscretizationConnection` corresponding to
+                each partition.
+    """
+
+    disc_conns = []
+    from meshmode.discretization.connection import (
+            DirectDiscretizationConnection, DiscretizationConnectionElementGroup)
+
+    # My intuition tells me that this should not live inside a for loop.
+    # However, I need to grab a cl_context. I'll assume that each context from
+    # each partition is the same and I'll use the first one.
+    cl_context = vol_to_bdry_conns[0].from_discr.cl_context
+    with cl.CommandQueue(cl_context) as queue:
+        # Create a list of batches. Each batch contains interpolation
+        #   data from one partition to another.
+        for i_tgt_part, tgt_vol_conn in enumerate(vol_to_bdry_conns):
+            bdry_discr = tgt_vol_conn.to_discr
+            tgt_mesh = tgt_vol_conn.to_discr.mesh
+            ngroups = len(tgt_mesh.groups)
+            part_batches = [[] for _ in range(ngroups)]
+            for tgt_group_num, adj in enumerate(tgt_mesh.interpart_adj_groups):
+                for idx, tgt_elem in enumerate(adj.elements):
+                    tgt_face = adj.element_faces[idx]
+
+                    # We need to create a batch using the
+                    # neighboring face, element, and group
+                    # I'm not sure how I would do this.
+                    # My guess is that it would look
+                    # something like _make_cross_face_batches
+                    part_batches[tgt_group_num].append(
+                            _make_cross_partition_batch(
+                                queue,
+                                vol_to_bdry_conns,
+                                adj,
+                                i_tgt_part,
+                                tgt_group_num,
+                                tgt_elem,
+                                tgt_face))
+
+            # Make one Discr connection for each partition.
+            disc_conns.append(DirectDiscretizationConnection(
+                    from_discr=bdry_discr,
+                    to_discr=bdry_discr,
+                    groups=[
+                        DiscretizationConnectionElementGroup(batches=batches)
+                        for batches in part_batches],
+                    is_surjective=True))
+
+    return disc_conns
+
 # vim: foldmethod=marker
diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py
index d3c9c8cfd11e4040ec7d507354776b9682f0fcae..eb4ee95720b176255855a96460c7ce99b5fe874c 100644
--- a/meshmode/mesh/__init__.py
+++ b/meshmode/mesh/__init__.py
@@ -54,6 +54,7 @@ Predefined Boundary tags
 .. autoclass:: BTAG_ALL
 .. autoclass:: BTAG_REALLY_ALL
 .. autoclass:: BTAG_NO_BOUNDARY
+.. autoclass:: BTAG_PARTITION
 """
 
 
@@ -88,6 +89,35 @@ class BTAG_NO_BOUNDARY(object):  # noqa
     pass
 
 
+class BTAG_PARTITION(object):  # noqa
+    """
+    A boundary tag indicating that this edge is adjacent to an element of
+    another :class:`Mesh`. The partition number of the adjacent mesh
+    is given by ``part_nr``.
+
+    .. attribute:: part_nr
+
+    .. versionadded:: 2017.1
+    """
+    def __init__(self, part_nr):
+        self.part_nr = int(part_nr)
+
+    # TODO is this acceptable?
+    # __eq__ is also defined so maybe the hash value isn't too important
+    # for dictionaries.
+    def __hash__(self):
+        return self.part_nr
+
+    def __eq__(self, other):
+        if isinstance(other, BTAG_PARTITION):
+            return self.part_nr == other.part_nr
+        else:
+            return False
+
+    def __nq__(self, other):
+        return not self.__eq__(other)
+
+
 SYSTEM_TAGS = set([BTAG_NONE, BTAG_ALL, BTAG_REALLY_ALL, BTAG_NO_BOUNDARY])
 
 # }}}
@@ -382,6 +412,80 @@ class NodalAdjacency(Record):
 # }}}
 
 
+# {{{ partition adjacency
+
+class InterPartitionAdj():
+    """
+    Describes facial adjacency information of elements in one :class:`Mesh` to
+    elements in another :class:`Mesh`. The element's boundary tag gives the
+    partition that it is connected to.
+
+    .. attribute:: elements
+
+        `:class:Mesh`-local element numbers that have neighbors.
+
+    .. attribute:: element_faces
+
+        ``element_faces[i]`` is the face of ``elements[i]`` that has a neighbor.
+
+    .. attribute:: neighbors
+
+        ``neighbors[i]`` gives the element number within the neighboring partiton
+        of the element connected to ``elements[i]``.
+
+    .. attribute:: neighbor_faces
+
+        ``neighbor_faces[i]`` gives face index within the neighboring partition
+        of the face connected to ``elements[i]``
+
+    .. automethod:: add_connection
+    .. automethod:: get_neighbor
+
+    .. versionadded:: 2017.1
+    """
+
+    def __init__(self):
+        self.elements = []
+        self.element_faces = []
+        self.neighbors = []
+        self.neighbor_faces = []
+        self.neighbor_groups = []
+        self.part_indices = []
+
+    def add_connection(self, elem, face, part_idx, neighbor_group, neighbor_elem, neighbor_face):
+        """
+        Adds a connection from ``elem`` and ``face`` within :class:`Mesh` to
+        ``neighbor_elem`` and ``neighbor_face`` of another neighboring partion
+        of type :class:`Mesh`.
+        :arg elem
+        :arg face
+        :arg part_idx
+        :arg neighbor_elem
+        :arg neighbor_face
+        """
+        self.elements.append(elem)
+        self.element_faces.append(face)
+        self.part_indices.append(part_idx)
+        self.neighbors.append(neighbor_elem)
+        self.neighbor_groups.append(neighbor_group)
+        self.neighbor_faces.append(neighbor_face)
+
+    def get_neighbor(self, elem, face):
+        """
+        :arg elem
+        :arg face
+        :returns: A tuple ``(part_idx, neighbor_group, neighbor_elem, neighbor_face)`` of 
+                    neighboring elements within another :class:`Mesh`.
+        """
+        for idx in range(len(self.elements)):
+            if elem == self.elements[idx] and face == self.element_faces[idx]:
+                return (self.part_indices[idx], self.neighbor_groups[idx],
+                         self.neighbors[idx], self.neighbor_faces[idx])
+        raise RuntimeError("This face does not have a neighbor")
+
+# }}}
+
+
 # {{{ facial adjacency
 
 class FacialAdjacencyGroup(Record):
@@ -533,6 +637,7 @@ class Mesh(Record):
             node_vertex_consistency_tolerance=None,
             nodal_adjacency=False,
             facial_adjacency_groups=False,
+            interpart_adj_groups=False,
             boundary_tags=None,
             vertex_id_dtype=np.int32,
             element_id_dtype=np.int32):
@@ -563,6 +668,7 @@ class Mesh(Record):
             will result in exceptions. Lastly, a data structure as described in
             :attr:`facial_adjacency_groups` may be passed.
         """
+
         el_nr = 0
         node_nr = 0
 
@@ -613,6 +719,7 @@ class Mesh(Record):
                 self, vertices=vertices, groups=new_groups,
                 _nodal_adjacency=nodal_adjacency,
                 _facial_adjacency_groups=facial_adjacency_groups,
+                interpart_adj_groups=interpart_adj_groups,
                 boundary_tags=boundary_tags,
                 btag_to_index=btag_to_index,
                 vertex_id_dtype=np.dtype(vertex_id_dtype),
@@ -742,6 +849,7 @@ class Mesh(Record):
                         == other._nodal_adjacency)
                 and (self._facial_adjacency_groups
                         == other._facial_adjacency_groups)
+                and self.interpart_adj_groups == other.interpart_adj_groups
                 and self.boundary_tags == other.boundary_tags)
 
     def __ne__(self, other):
@@ -922,6 +1030,7 @@ def _compute_facial_adjacency_from_vertices(mesh):
 
         for ineighbor_group in range(len(mesh.groups)):
             nb_count = group_count.get((igroup, ineighbor_group))
+            # FIXME nb_count is None sometimes when it maybe shouldn't be.
             if nb_count is not None:
                 elements = np.empty(nb_count, dtype=mesh.element_id_dtype)
                 element_faces = np.empty(nb_count, dtype=mesh.face_id_dtype)
diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 31abb9beb27d68090ffad264c901b01c69729b63..48effb1d1fe99a750a1d1964d47adb45b3c5a6f2 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -80,8 +80,8 @@ def partition_mesh(mesh, part_per_element, part_nr):
     skip_groups = []
     num_prev_elems = 0
     start_idx = 0
-    for group_nr in range(num_groups):
-        mesh_group = mesh.groups[group_nr]
+    for group_num in range(num_groups):
+        mesh_group = mesh.groups[group_num]
 
         # Find the index of first element in the next group
         end_idx = len(queried_elems)
@@ -91,7 +91,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
                 break
 
         if start_idx == end_idx:
-            skip_groups.append(group_nr)
+            skip_groups.append(group_num)
             new_indices.append(np.array([]))
             new_nodes.append(np.array([]))
             num_prev_elems += mesh_group.nelements
@@ -107,10 +107,10 @@ def partition_mesh(mesh, part_per_element, part_nr):
             for j in range(start_idx, end_idx):
                 elems = queried_elems[j] - num_prev_elems
                 new_idx = j - start_idx
-                new_nodes[group_nr][i, new_idx, :] = mesh_group.nodes[i, elems, :]
+                new_nodes[group_num][i, new_idx, :] = mesh_group.nodes[i, elems, :]
 
-        #index_set = np.append(index_set, new_indices[group_nr].ravel())
-        index_sets = np.append(index_sets, set(new_indices[group_nr].ravel()))
+        #index_set = np.append(index_set, new_indices[group_num].ravel())
+        index_sets = np.append(index_sets, set(new_indices[group_num].ravel()))
 
         num_prev_elems += mesh_group.nelements
         start_idx = end_idx
@@ -124,24 +124,79 @@ def partition_mesh(mesh, part_per_element, part_nr):
         new_vertices[dim] = mesh.vertices[dim][required_indices]
 
     # Our indices need to be in range [0, len(mesh.nelements)].
-    for group_nr in range(num_groups):
-        if group_nr not in skip_groups:
-            for i in range(len(new_indices[group_nr])):
-                for j in range(len(new_indices[group_nr][0])):
-                    original_index = new_indices[group_nr][i, j]
-                    new_indices[group_nr][i, j] = np.where(
-                        required_indices == original_index)[0]
+    for group_num in range(num_groups):
+        if group_num not in skip_groups:
+            for i in range(len(new_indices[group_num])):
+                for j in range(len(new_indices[group_num][0])):
+                    original_index = new_indices[group_num][i, j]
+                    new_indices[group_num][i, j] = np.where(
+                            required_indices == original_index)[0]
 
     new_mesh_groups = []
-    for group_nr in range(num_groups):
-        if group_nr not in skip_groups:
-            mesh_group = mesh.groups[group_nr]
+    for group_num in range(num_groups):
+        if group_num not in skip_groups:
+            mesh_group = mesh.groups[group_num]
             new_mesh_groups.append(
-                type(mesh_group)(mesh_group.order, new_indices[group_nr],
-                    new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
+                type(mesh_group)(mesh_group.order, new_indices[group_num],
+                    new_nodes[group_num], unit_nodes=mesh_group.unit_nodes))
+
+    from meshmode.mesh import BTAG_ALL, BTAG_PARTITION
+    boundary_tags = [BTAG_PARTITION(n) for n in range(np.max(part_per_element))]
 
     from meshmode.mesh import Mesh
-    part_mesh = Mesh(new_vertices, new_mesh_groups)
+    part_mesh = Mesh(new_vertices, new_mesh_groups,
+        facial_adjacency_groups=None, boundary_tags=boundary_tags)
+
+    # FIXME I get errors when I try to copy part_mesh.
+    from meshmode.mesh import InterPartitionAdj
+    part_mesh.interpart_adj_groups = [
+                    InterPartitionAdj() for _ in range(num_groups)]
+
+    for igrp in range(num_groups):
+        elem_base = part_mesh.groups[igrp].element_nr_base
+        boundary_adj = part_mesh.facial_adjacency_groups[igrp][None]
+        boundary_elems = boundary_adj.elements
+        boundary_faces = boundary_adj.element_faces
+        for elem_idx in range(len(boundary_elems)):
+            elem = boundary_elems[elem_idx]
+            face = boundary_faces[elem_idx]
+            tags = -boundary_adj.neighbors[elem_idx]
+            assert tags >= 0, "Expected boundary tag in adjacency group."
+            parent_elem = queried_elems[elem]
+            parent_group_num = 0
+            while parent_elem >= mesh.groups[parent_group_num].nelements:
+                parent_elem -= mesh.groups[parent_group_num].nelements
+                parent_group_num += 1
+            assert parent_group_num < num_groups, "Unable to find neighbor."
+            parent_grp_elem_base = mesh.groups[parent_group_num].element_nr_base
+            parent_adj = mesh.facial_adjacency_groups[parent_group_num]
+            for n_grp_num, parent_facial_group in parent_adj.items():
+                for idx in np.where(parent_facial_group.elements == parent_elem)[0]:
+                    if parent_facial_group.neighbors[idx] >= 0 and \
+                            parent_facial_group.element_faces[idx] == face:
+                        rank_neighbor = (parent_facial_group.neighbors[idx]
+                                         + parent_grp_elem_base)
+                        rank_neighbor_face = parent_facial_group.neighbor_faces[idx]
+
+                        n_part_num = part_per_element[rank_neighbor]
+                        tags = tags & ~part_mesh.boundary_tag_bit(BTAG_ALL)
+                        tags = tags | part_mesh.boundary_tag_bit(
+                                                    BTAG_PARTITION(n_part_num))
+                        boundary_adj.neighbors[elem_idx] = -tags
+
+                        # Find the neighbor element from the other partition
+                        n_elem = np.count_nonzero(
+                                    part_per_element[:rank_neighbor] == n_part_num)
+
+                        # TODO Test if this works with multiple groups
+                        # Do I need to add the element number base?
+                        part_mesh.interpart_adj_groups[igrp].add_connection(
+                            elem + elem_base,
+                            face,
+                            n_part_num,
+                            n_grp_num,
+                            n_elem,
+                            rank_neighbor_face)
 
     return part_mesh, queried_elems
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 413a107d1a3d33b52b1a0fd2eea2047822db54ed..cb36e4da372d5193387861697bfc6497a26d89c0 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -48,33 +48,59 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-# {{{ partition_mesh
+def test_partition_interpolation(ctx_getter):
+    cl_ctx = ctx_getter()
+    order = 4
+    group_factory = PolynomialWarpAndBlendGroupFactory(order)
+    n = 3
+    dim = 2
+    num_parts = 7
+    from meshmode.mesh.generation import generate_warped_rect_mesh
+    mesh = generate_warped_rect_mesh(dim, order=order, n=n)
 
-def test_partition_torus_mesh():
-    from meshmode.mesh.generation import generate_torus
-    my_mesh = generate_torus(2, 1, n_outer=2, n_inner=2)
+    adjacency_list = np.zeros((mesh.nelements,), dtype=set)
+    for elem in range(mesh.nelements):
+        adjacency_list[elem] = set()
+        starts = mesh.nodal_adjacency.neighbors_starts
+        for n in range(starts[elem], starts[elem + 1]):
+            adjacency_list[elem].add(mesh.nodal_adjacency.neighbors[n])
 
-    part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0])
+    from pymetis import part_graph
+    (_, p) = part_graph(num_parts, adjacency=adjacency_list)
+    part_per_element = np.array(p)
 
     from meshmode.mesh.processing import partition_mesh
-    (part_mesh0, _) = partition_mesh(my_mesh, part_per_element, 0)
-    (part_mesh1, _) = partition_mesh(my_mesh, part_per_element, 1)
-    (part_mesh2, _) = partition_mesh(my_mesh, part_per_element, 2)
+    part_meshes = [
+        partition_mesh(mesh, part_per_element, i)[0] for i in range(num_parts)]
 
-    assert part_mesh0.nelements == 2
-    assert part_mesh1.nelements == 4
-    assert part_mesh2.nelements == 2
+    from meshmode.discretization import Discretization
+    vol_discrs = [Discretization(cl_ctx, part_meshes[i], group_factory)
+                    for i in range(num_parts)]
+
+    from meshmode.discretization.connection import make_face_restriction
+    bdry_connections = [make_face_restriction(vol_discrs[i], group_factory,
+                            FRESTR_INTERIOR_FACES) for i in range(num_parts)]
 
+    from meshmode.discretization.connection import make_partition_connection
+    connections = make_partition_connection(bdry_connections)
 
-def test_partition_boxes_mesh():
+    from meshmode.discretization.connection import check_connection
+    for conn in connections:
+        check_connection(conn)
+
+
+# {{{ partition_mesh
+
+def test_partition_mesh():
     n = 5
     num_parts = 7
     from meshmode.mesh.generation import generate_regular_rect_mesh
-    mesh1 = generate_regular_rect_mesh(a=(0, 0, 0), b=(1, 1, 1), n=(n, n, n))
-    mesh2 = generate_regular_rect_mesh(a=(2, 2, 2), b=(3, 3, 3), n=(n, n, n))
+    mesh = generate_regular_rect_mesh(a=(0, 0, 0), b=(1, 1, 1), n=(n, n, n))
+    #TODO facial_adjacency_groups is not available from merge_disjoint_meshes.
+    #mesh2 = generate_regular_rect_mesh(a=(2, 2, 2), b=(3, 3, 3), n=(n, n, n))
 
-    from meshmode.mesh.processing import merge_disjoint_meshes
-    mesh = merge_disjoint_meshes([mesh1, mesh2])
+    #from meshmode.mesh.processing import merge_disjoint_meshes
+    #mesh = merge_disjoint_meshes([mesh1, mesh2])
 
     adjacency_list = np.zeros((mesh.nelements,), dtype=set)
     for elem in range(mesh.nelements):
@@ -89,10 +115,74 @@ def test_partition_boxes_mesh():
 
     from meshmode.mesh.processing import partition_mesh
     new_meshes = [
-        partition_mesh(mesh, part_per_element, i)[0] for i in range(num_parts)]
+        partition_mesh(mesh, part_per_element, i) for i in range(num_parts)]
 
     assert mesh.nelements == np.sum(
-        [new_meshes[i].nelements for i in range(num_parts)])
+        [new_meshes[i][0].nelements for i in range(num_parts)]), \
+        "part_mesh has the wrong number of elements"
+
+    assert count_tags(mesh, BTAG_ALL) == np.sum(
+        [count_tags(new_meshes[i][0], BTAG_ALL) for i in range(num_parts)]), \
+        "part_mesh has the wrong number of BTAG_ALL boundaries"
+
+    from meshmode.mesh import BTAG_PARTITION
+    num_tags = np.zeros((num_parts,))
+
+    for part_num in range(num_parts):
+        (part, part_to_global) = new_meshes[part_num]
+        for grp_num, f_groups in enumerate(part.facial_adjacency_groups):
+            f_grp = f_groups[None]
+            for idx in range(len(f_grp.elements)):
+                tag = -f_grp.neighbors[idx]
+                assert tag >= 0
+                elem = f_grp.elements[idx]
+                face = f_grp.element_faces[idx]
+                for n_part_num in range(num_parts):
+                    (n_part, n_part_to_global) = new_meshes[n_part_num]
+                    if tag & part.boundary_tag_bit(BTAG_PARTITION(n_part_num)) != 0:
+                        num_tags[n_part_num] += 1
+                        (n_part_idx, n_grp_num, n_elem, n_face) = part.\
+                            interpart_adj_groups[grp_num].get_neighbor(elem, face)
+                        assert n_part_idx == n_part_num
+                        assert (part_num, grp_num, elem, face) == n_part.\
+                                            interpart_adj_groups[n_grp_num].\
+                                            get_neighbor(n_elem, n_face),\
+                                            "InterpartitionAdj is not consistent"
+                        p_elem = part_to_global[elem]
+                        n_part_to_global = new_meshes[n_part_num][1]
+                        p_n_elem = n_part_to_global[n_elem]
+                        p_grp_num = 0
+                        while p_elem >= mesh.groups[p_grp_num].nelements:
+                            p_elem -= mesh.groups[p_grp_num].nelements
+                            p_grp_num += 1
+                        #p_elem_base = mesh.groups[p_grp_num].element_num_base
+                        f_groups = mesh.facial_adjacency_groups[p_grp_num]
+                        for _, p_bnd_adj in f_groups.items():
+                            for idx in range(len(p_bnd_adj.elements)):
+                                if (p_elem == p_bnd_adj.elements[idx] and
+                                         face == p_bnd_adj.element_faces[idx]):
+                                    assert p_n_elem == p_bnd_adj.neighbors[idx],\
+                                            "Tag does not give correct neighbor"
+                                    assert n_face == p_bnd_adj.neighbor_faces[idx],\
+                                            "Tag does not give correct neighbor"
+
+    for tag_num in range(num_parts):
+        tag_sum = 0
+        for mesh, _ in new_meshes:
+            tag_sum += count_tags(mesh, BTAG_PARTITION(tag_num))
+        assert num_tags[tag_num] == tag_sum,\
+                "part_mesh has the wrong number of BTAG_PARTITION boundaries"
+
+
+def count_tags(mesh, tag):
+    num_bnds = 0
+    for adj_dict in mesh.facial_adjacency_groups:
+        for _, bdry_group in adj_dict.items():
+            for neighbors in bdry_group.neighbors:
+                if neighbors < 0:
+                    if -neighbors & mesh.boundary_tag_bit(tag) != 0:
+                        num_bnds += 1
+    return num_bnds
 
 # }}}