diff --git a/meshmode/discretization/connection/__init__.py b/meshmode/discretization/connection/__init__.py
index 8ea924efd1d7a5235747745d8e99f54a5a51a612..ba15835834a1dd7afbc53bc8c35ff258d7778e42 100644
--- a/meshmode/discretization/connection/__init__.py
+++ b/meshmode/discretization/connection/__init__.py
@@ -36,6 +36,9 @@ from meshmode.discretization.connection.face import (
         make_face_restriction, make_face_to_all_faces_embedding)
 from meshmode.discretization.connection.opposite_face import \
         make_opposite_face_connection
+from meshmode.discretization.connection.refinement import \
+        make_refinement_connection
+
 
 import logging
 logger = logging.getLogger(__name__)
@@ -47,7 +50,8 @@ __all__ = [
         "FRESTR_INTERIOR_FACES", "FRESTR_ALL_FACES",
         "make_face_restriction",
         "make_face_to_all_faces_embedding",
-        "make_opposite_face_connection"
+        "make_opposite_face_connection",
+        "make_refinement_connection"
         ]
 
 __doc__ = """
@@ -62,6 +66,8 @@ __doc__ = """
 
 .. autofunction:: make_opposite_face_connection
 
+.. autofunction:: make_refinement_connection
+
 Implementation details
 ^^^^^^^^^^^^^^^^^^^^^^
 
@@ -434,11 +440,4 @@ def check_connection(connection):
 # }}}
 
 
-# {{{ refinement connection
-
-def make_refinement_connection(refiner, coarse_discr):
-    pass
-
-# }}}
-
 # vim: foldmethod=marker
diff --git a/meshmode/discretization/connection/refinement.py b/meshmode/discretization/connection/refinement.py
new file mode 100644
index 0000000000000000000000000000000000000000..75604ec05dd945cde45e0e6cdd55724f88083c3f
--- /dev/null
+++ b/meshmode/discretization/connection/refinement.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+from __future__ import division, print_function, absolute_import
+
+__copyright__ = "Copyright (C) 2016 Matt Wala"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import numpy as np
+import pyopencl as cl
+import pyopencl.array  # noqa
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+# {{{ Map unit nodes to children
+
+def _map_unit_nodes_to_children(unit_nodes, tesselation):
+    """
+    Given a collection of unit nodes, return the coordinates of the
+    unit nodes mapped onto each of the children of the reference
+    element.
+
+    The tesselation should follow the format of
+    :func:`meshmode.mesh.tesselate.tesselatetri()` or
+    :func:`meshmode.mesh.tesselate.tesselatetet()`.
+
+    `unit_nodes` should be relative to the unit simplex coordinates in
+    :mod:`modepy`.
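+
+    For instance (2-D, illustrative): for a child whose reference vertices
+    are `(0, 0)`, `(1, 0)` and `(0, 1)`, the returned nodes lie in the
+    quarter-sized corner triangle with unit coordinates `(-1, -1)`,
+    `(0, -1)` and `(-1, 0)`.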
+
+    :arg unit_nodes: an array of shape `(dim, nunit_nodes)`
+    :arg tesselation: an object with attributes `ref_vertices` and `children`
+    """
+    ref_vertices = np.array(tesselation.ref_vertices, dtype=np.float)
+
+    for child_element in tesselation.children:
+        center = np.vstack(ref_vertices[child_element[0]])
+        # Scale by 1/2 since sides in the tesselation have length 2.
+        aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2
+        # (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices.
+        # Hence the translation by +/- 1.
+        yield aff_mat.dot(unit_nodes + 1) + center - 1
+
+# }}}
+
+
+# {{{ Build interpolation batches for group
+
+def _build_interpolation_batches_for_group(
+        queue, group_idx, coarse_discr_group, fine_discr_group, record):
+    """
+    To map between discretizations, we sort each of the fine mesh
+    elements into an interpolation batch.  Which batch they go
+    into is determined by where the refined unit nodes live
+    relative to the coarse reference element.
+
+    For instance, consider the following refinement:
+
+     ______      ______
+    |\     |    |\    e|
+    | \    |    |d\    |
+    |  \   |    |__\   |
+    |   \  | => |\c|\  |
+    |    \ |    |a\|b\ |
+    |     \|    |  |  \|
+     ‾‾‾‾‾‾      ‾‾‾‾‾‾
+
+    Here, the discretization unit nodes for elements a,b,c,d,e
+    will each have different positions relative to the reference
+    element, so each element gets its own batch. On the other
+    hand, for
+
+     ______      ______
+    |\     |    |\ f|\e|
+    | \    |    |d\ |g\|
+    |  \   |    |__\|__|
+    |   \  | => |\c|\  |
+    |    \ |    |a\|b\h|
+    |     \|    |  |  \|
+     ‾‾‾‾‾‾      ‾‾‾‾‾‾
+
+    the pairs {a,e}, {b,f}, {c,g}, {d,h} can share interpolation
+    batches because their unit nodes are mapped from the same part
+    of the reference element.
+    """
+    from meshmode.discretization.connection import InterpolationBatch
+
+    num_children = (len(record.tesselation.children)
+                    if record.tesselation else 0)
+    from_bins = [[] for i in range(1 + num_children)]
+    to_bins = [[] for i in range(1 + num_children)]
+    for elt_idx, refinement_result in enumerate(record.element_mapping):
+        if len(refinement_result) == 1:
+            # Not refined -> interpolates to self
+            from_bins[0].append(elt_idx)
+            to_bins[0].append(refinement_result[0])
+        else:
+            assert len(refinement_result) == num_children
+            # Refined -> interpolates to children
+            for from_bin, to_bin, child_idx in zip(
+                    from_bins[1:], to_bins[1:], refinement_result):
+                from_bin.append(elt_idx)
+                to_bin.append(child_idx)
+
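+    # For instance (2-D, illustrative): in a group with two coarse elements
+    # where element 0 is left alone and element 1 is refined into fine
+    # elements [1, 2, 3, 4], the loop above yields
+    # from_bins == [[0], [1], [1], [1], [1]] and
+    # to_bins == [[0], [1], [2], [3], [4]].
+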
+    fine_unit_nodes = fine_discr_group.unit_nodes
+    mapped_unit_nodes = _map_unit_nodes_to_children(
+        fine_unit_nodes, record.tesselation)
+
+    from itertools import chain
+    for from_bin, to_bin, unit_nodes in zip(
+            from_bins, to_bins,
+            chain([fine_unit_nodes], mapped_unit_nodes)):
+        if not from_bin:
+            continue
+        yield InterpolationBatch(
+            from_group_index=group_idx,
+            from_element_indices=cl.array.to_device(queue, np.asarray(from_bin)),
+            to_element_indices=cl.array.to_device(queue, np.asarray(to_bin)),
+            result_unit_nodes=unit_nodes,
+            to_element_face=None)
+
+# }}}
+
+
+def make_refinement_connection(refiner, coarse_discr, group_factory):
+    """Return a
+    :class:`meshmode.discretization.connection.DiscretizationConnection`
+    connecting `coarse_discr` to a discretization on the fine mesh.
+
+    :arg refiner: An instance of
+        :class:`meshmode.mesh.refinement.Refiner`
+
+    :arg coarse_discr: An instance of
+        :class:`meshmode.discretization.Discretization` associated
+        with the mesh given to the refiner.
+
+    :arg group_factory: An instance of
+        :class:`meshmode.discretization.poly_element.ElementGroupFactory`.
+        Used for discretizing the fine mesh.
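+
+    A sketch of typical use (`refine_flags`, `queue` and `coarse_vec` are
+    placeholder names)::
+
+        refiner = Refiner(coarse_discr.mesh)
+        refiner.refine(refine_flags)
+        connection = make_refinement_connection(
+            refiner, coarse_discr, group_factory)
+        fine_vec = connection(queue, coarse_vec)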
+    """
+    from meshmode.discretization.connection import (
+        DiscretizationConnectionElementGroup, DiscretizationConnection)
+
+    coarse_mesh = refiner.get_previous_mesh()
+    fine_mesh = refiner.last_mesh
+    assert coarse_discr.mesh is coarse_mesh
+
+    from meshmode.discretization import Discretization
+    fine_discr = Discretization(
+        coarse_discr.cl_context,
+        fine_mesh,
+        group_factory,
+        real_dtype=coarse_discr.real_dtype)
+
+    logger.info("building refinement connection: start")
+
+    groups = []
+    with cl.CommandQueue(fine_discr.cl_context) as queue:
+        for group_idx, (coarse_discr_group, fine_discr_group, record) in \
+                enumerate(zip(coarse_discr.groups, fine_discr.groups,
+                              refiner.group_refinement_records)):
+            groups.append(
+                DiscretizationConnectionElementGroup(
+                    list(_build_interpolation_batches_for_group(
+                            queue, group_idx, coarse_discr_group,
+                            fine_discr_group, record))))
+
+    logger.info("building refinement connection: done")
+
+    return DiscretizationConnection(
+        from_discr=coarse_discr,
+        to_discr=fine_discr,
+        groups=groups,
+        is_surjective=True)
+
+
+# vim: foldmethod=marker
diff --git a/meshmode/mesh/refinement/__init__.py b/meshmode/mesh/refinement/__init__.py
index 0430ada9929bb270e230bf183d324141e890804c..2154d45195412c6c012a54ace9b8fbcb07561055 100644
--- a/meshmode/mesh/refinement/__init__.py
+++ b/meshmode/mesh/refinement/__init__.py
@@ -24,6 +24,7 @@ THE SOFTWARE.
 import numpy as np
 import itertools
 from six.moves import range
+from pytools import RecordWithoutPickling
 
 from meshmode.mesh.generation import make_group_from_vertices
 
@@ -57,18 +58,31 @@ class TreeRayNode(object):
 
 class Refiner(object):
 
+    class _Tesselation(RecordWithoutPickling):
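+        # *ref_vertices* holds the node tuples (vertices and edge midpoints,
+        # in doubled coordinates) of the reference simplex; *children* gives,
+        # for each child element, its vertex indices into *ref_vertices*.
+        # See meshmode.mesh.tesselate.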
+
+        def __init__(self, children, ref_vertices):
+            RecordWithoutPickling.__init__(self,
+                ref_vertices=ref_vertices, children=children)
+
+    class _GroupRefinementRecord(RecordWithoutPickling):
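+        # Per element group: *tesselation* is the _Tesselation used for the
+        # refined elements (None if no element in the group was refined),
+        # and *element_mapping* maps each old element number to the list of
+        # element numbers it becomes after refinement.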
+
+        def __init__(self, tesselation, element_mapping):
+            RecordWithoutPickling.__init__(self,
+                tesselation=tesselation, element_mapping=element_mapping)
+
     # {{{ constructor
 
     def __init__(self, mesh):
         from meshmode.mesh.tesselate import tesselatetet, tesselatetri
         self.lazy = False
         self.seen_tuple = {}
+        self.group_refinement_records = []
         tri_node_tuples, tri_result = tesselatetri()
         tet_node_tuples, tet_result = tesselatetet()
-        print(tri_node_tuples, tri_result)
         #quadrilateral_node_tuples = [
         #print tri_result, tet_result
         self.simplex_node_tuples = [None, None, tri_node_tuples, tet_node_tuples]
+        # Dimension-parameterized tesselations for refinement
         self.simplex_result = [None, None, tri_result, tet_result]
         #print tri_node_tuples, tri_result
         #self.simplex_node_tuples, self.simplex_result = tesselatetet()
@@ -89,6 +103,8 @@ class Refiner(object):
         for i in range(nvertices):
             self.hanging_vertex_element.append([])
 
+        # Fill pair_map.
+        # Add adjacency information to each TreeRayNode.
         for grp in mesh.groups:
             iel_base = grp.element_nr_base
             for iel_grp in range(grp.nelements):
@@ -195,6 +211,9 @@ class Refiner(object):
                 self.last_mesh.nelements - self.get_refine_base_index(),
                 np.bool)
 
+    def get_previous_mesh(self):
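+        """Return the mesh before the most recent refinement."""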
+        return self.previous_mesh
+
     def get_current_mesh(self):
 
         from meshmode.mesh import Mesh
@@ -417,6 +436,8 @@ class Refiner(object):
 #                for next_vertices in next_vertices_list:
 #                    remove_element_from_connectivity(next_vertices, new_hanging_vertex_elements, to_remove)
 
+        # {{{ Add element to connectivity
+
         def add_element_to_connectivity(vertices, new_hanging_vertex_elements, to_add):
             if len(vertices) == 2:
                 min_vertex = min(vertices[0], vertices[1])
@@ -502,14 +523,22 @@ class Refiner(object):
 #                        return
 #            add_element_to_connectivity(next_element_rays, new_hanging_vertex_elements, to_add)
 
+        # }}}
+
+        # {{{ Add hanging vertex element
+
         def add_hanging_vertex_el(v_index, el):
             assert not (v_index == 37 and el == 48)
 
             new_hanging_vertex_element[v_index].append(el)
 
+        # }}}
+
 #        def remove_ray_el(ray, el):
 #            ray.remove(el)
 
+        # {{{ Check adjacent elements
+
         def check_adjacent_elements(groups, new_hanging_vertex_elements, nelements_in_grp):
             for grp in groups:
                 iel_base = 0
@@ -529,6 +558,8 @@ class Refiner(object):
                             assert((iel_base+iel_grp) in new_hanging_vertex_elements[cur_node.left_vertex])
                             assert((iel_base+iel_grp) in new_hanging_vertex_elements[cur_node.right_vertex])
 
+        # }}}
+
         for i in range(len(self.last_mesh.vertices)):
             for j in range(len(self.last_mesh.vertices[i])):
                 vertices[i,j] = self.last_mesh.vertices[i,j]
@@ -545,14 +576,24 @@ class Refiner(object):
         grpn = 0
         vertices_index = len(self.last_mesh.vertices[0])
         nelements_in_grp = grp.nelements
-        for grp in self.last_mesh.groups:
+        del self.group_refinement_records[:]
+
+        for grp_idx, grp in enumerate(self.last_mesh.groups):
             iel_base = grp.element_nr_base
+            # List of lists mapping element number to new element number(s).
+            element_mapping = []
+            tesselation = None
+
             for iel_grp in range(grp.nelements):
+                element_mapping.append([iel_grp])
                 if refine_flags[iel_base+iel_grp]:
                     midpoint_vertices = []
                     vertex_indices = grp.vertex_indices[iel_grp]
                     #if simplex
                     if len(grp.vertex_indices[iel_grp]) == len(self.last_mesh.vertices)+1:
+
+                        # {{{ Get midpoints for all pairs of vertices
+
                         for i in range(len(vertex_indices)):
                             for j in range(i+1, len(vertex_indices)):
                                 min_index = min(vertex_indices[i], vertex_indices[j])
@@ -581,6 +622,7 @@ class Refiner(object):
                                     cur_midpoint = cur_node.midpoint
                                     midpoint_vertices.append(cur_midpoint)
 
+                        # }}}
 
                         #generate new rays
                         cur_dim = len(grp.vertex_indices[0])-1
@@ -598,13 +640,16 @@ class Refiner(object):
                         for midpoint_index, midpoint_tuple in enumerate(self.index_to_midpoint_tuple[cur_dim]):
                             node_tuple_to_coord[midpoint_tuple] = midpoint_vertices[midpoint_index]
                         for i in range(len(self.simplex_result[cur_dim])):
+                            if i == 0:
+                                iel = iel_grp
+                            else:
+                                iel = nelements_in_grp + i - 1
+                                element_mapping[-1].append(iel)
                             for j in range(len(self.simplex_result[cur_dim][i])):
-                                if i == 0:
-                                    groups[grpn][iel_grp][j] = \
-                                            node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
-                                else:
-                                    groups[grpn][nelements_in_grp+i-1][j] = \
-                                            node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
+                                groups[grpn][iel][j] = \
+                                    node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
+                        tesselation = self._Tesselation(
+                            self.simplex_result[cur_dim], self.simplex_node_tuples[cur_dim])
                         nelements_in_grp += len(self.simplex_result[cur_dim])-1
                     #assuming quad otherwise
                     #else:
@@ -615,6 +660,8 @@ class Refiner(object):
 #                        def generate_all_tuples(cur_list):
 #                            if len(cur_list[len(cur_list)-1])
 
+            self.group_refinement_records.append(
+                self._GroupRefinementRecord(tesselation, element_mapping))
 
         #clear connectivity data
         for grp in self.last_mesh.groups:
@@ -653,6 +700,7 @@ class Refiner(object):
 
         from meshmode.mesh import Mesh
 
+        self.previous_mesh = self.last_mesh
         self.last_mesh = Mesh(
                 vertices, grp,
                 nodal_adjacency=self.generate_nodal_adjacency(
diff --git a/test/test_refinement.py b/test/test_refinement.py
index aa1ab68a725c73c8618e32fa7c7a8dd518a4420b..9c36b9750ebe87a899a8dd5c5f17f4191ed1b0b4 100644
--- a/test/test_refinement.py
+++ b/test/test_refinement.py
@@ -23,6 +23,8 @@ THE SOFTWARE.
 """
 
 import pytest
+import pyopencl as cl
+import pyopencl.clmath  # noqa
 
 import numpy as np
 from pyopencl.tools import (  # noqa
@@ -33,19 +35,25 @@ from meshmode.mesh.generation import (  # noqa
 from meshmode.mesh.refinement.utils import check_nodal_adj_against_geometry
 from meshmode.mesh.refinement import Refiner
 
+from meshmode.discretization.poly_element import (
+    InterpolatoryQuadratureSimplexGroupFactory,
+    PolynomialWarpAndBlendGroupFactory,
+    PolynomialEquidistantGroupFactory,
+)
+
 import logging
 logger = logging.getLogger(__name__)
 
 from functools import partial
 
 
-def gen_blob_mesh():
+def gen_blob_mesh(h=0.2):
     from meshmode.mesh.io import generate_gmsh, FileSource
     return generate_gmsh(
             FileSource("blob-2d.step"), 2, order=1,
             force_ambient_dim=2,
             other_options=[
-                "-string", "Mesh.CharacteristicLengthMax = %s;" % 0.2]
+                "-string", "Mesh.CharacteristicLengthMax = %s;" % h]
             )
 
 
@@ -124,6 +132,84 @@ def test_refinement(case_name, mesh_gen, flag_gen, num_generations):
         check_nodal_adj_against_geometry(mesh)
 
 
+@pytest.mark.parametrize("group_factory", [
+    InterpolatoryQuadratureSimplexGroupFactory,
+    PolynomialWarpAndBlendGroupFactory,
+    PolynomialEquidistantGroupFactory
+    ])
+@pytest.mark.parametrize(("mesh_name", "dim", "mesh_pars"), [
+    ("blob", 2, [1e-1, 8e-2, 5e-2]),
+    ("warp", 2, [4, 5, 6]),
+    ("warp", 3, [4, 5, 6]),
+])
+@pytest.mark.parametrize("refine_flags", [
+    # FIXME: slow
+    # uniform_refine_flags,
+    partial(random_refine_flags, 0.4)
+])
+def test_refinement_connection(
+        ctx_getter, group_factory, mesh_name, dim, mesh_pars, refine_flags):
+    from random import seed
+    seed(13)
+
+    cl_ctx = ctx_getter()
+    queue = cl.CommandQueue(cl_ctx)
+
+    from meshmode.discretization import Discretization
+    from meshmode.discretization.connection import (
+            make_refinement_connection, check_connection)
+
+    from pytools.convergence import EOCRecorder
+    eoc_rec = EOCRecorder()
+
+    order = 5
+
+    def f(x):
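+        # Evaluates 0.1 * x[0] * sin(5*x[1]) * ... * sin(5*x[dim-1]),
+        # a smooth test function coupling all coordinates.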
+        from six.moves import reduce
+        return 0.1 * reduce(lambda acc, y: acc * cl.clmath.sin(5 * y), x)
+
+    for mesh_par in mesh_pars:
+        # {{{ get mesh
+
+        if mesh_name == "blob":
+            assert dim == 2
+            h = mesh_par
+            mesh = gen_blob_mesh(h)
+        elif mesh_name == "warp":
+            from meshmode.mesh.generation import generate_warped_rect_mesh
+            mesh = generate_warped_rect_mesh(dim, order=1, n=mesh_par)
+            h = 1/mesh_par
+        else:
+            raise ValueError("mesh_name not recognized")
+
+        # }}}
+
+        discr = Discretization(cl_ctx, mesh, group_factory(order))
+
+        refiner = Refiner(mesh)
+        refiner.refine(refine_flags(mesh))
+        connection = make_refinement_connection(
+            refiner, discr, group_factory(order))
+        check_connection(connection)
+
+        fine_discr = connection.to_discr
+
+        x = discr.nodes().with_queue(queue)
+        x_fine = fine_discr.nodes().with_queue(queue)
+        f_coarse = f(x)
+        f_interp = connection(queue, f_coarse).with_queue(queue)
+        f_true = f(x_fine).with_queue(queue)
+
+        import numpy.linalg as la
+        err = la.norm((f_interp - f_true).get(queue), np.inf)
+        eoc_rec.add_data_point(h, err)
+
+    print(eoc_rec)
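+    # Interpolation onto the refined mesh should either converge at
+    # (roughly) the discretization order or already be exact up to
+    # round-off.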
+    assert (
+            eoc_rec.order_estimate() >= order-0.5
+            or eoc_rec.max_error() < 1e-14)
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1: