diff --git a/meshmode/discretization/connection/refinement.py b/meshmode/discretization/connection/refinement.py
index 5ef91fe7e41dc464061e73438796afa174756560..ff009fce6efe7833e7a8bbd43c6d33745e722704 100644
--- a/meshmode/discretization/connection/refinement.py
+++ b/meshmode/discretization/connection/refinement.py
@@ -31,39 +31,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-# {{{ Map unit nodes to children
-
-def _map_unit_nodes_to_children(unit_nodes, tesselation):
-    """
-    Given a collection of unit nodes, return the coordinates of the
-    unit nodes mapped onto each of the children of the reference
-    element.
-
-    The tesselation should follow the format of
-    :func:`meshmode.mesh.tesselate.tesselatetri()` or
-    :func:`meshmode.mesh.tesselate.tesselatetet()`.
-
-    `unit_nodes` should be relative to the unit simplex coordinates in
-    :module:`modepy`.
-
-    :arg unit_nodes: shaped `(dim, nunit_nodes)`
-    :arg tesselation: With attributes `ref_vertices`, `children`
-    """
-    ref_vertices = np.array(tesselation.ref_vertices, dtype=np.float)
-
-    assert len(unit_nodes.shape) == 2
-
-    for child_element in tesselation.children:
-        center = np.vstack(ref_vertices[child_element[0]])
-        # Scale by 1/2 since sides in the tesselation have length 2.
-        aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2
-        # (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices.
-        # Hence the translation by +/- 1.
-        yield aff_mat.dot(unit_nodes + 1) + center - 1
-
-# }}}
-
-
 # {{{ Build interpolation batches for group
 
 def _build_interpolation_batches_for_group(
@@ -123,7 +90,8 @@ def _build_interpolation_batches_for_group(
                 to_bin.append(child_idx)
 
     fine_unit_nodes = fine_discr_group.unit_nodes
-    mapped_unit_nodes = _map_unit_nodes_to_children(
+    from meshmode.mesh.refinement.utils import map_unit_nodes_to_children
+    mapped_unit_nodes = map_unit_nodes_to_children(
         fine_unit_nodes, record.tesselation)
 
     from itertools import chain
diff --git a/meshmode/mesh/generation.py b/meshmode/mesh/generation.py
index ce2a5f236d0ec524a411c78c87b7904abd423db6..ff5d8581d5d74291e48ecafe5936548b290e0960 100644
--- a/meshmode/mesh/generation.py
+++ b/meshmode/mesh/generation.py
@@ -547,7 +547,7 @@ def generate_warped_rect_mesh(dim, order, n):
                 0.05*np.cos(10*x[0])
                 + 1.3*x[1] + np.sin(x[1]))
         if len(x) == 3:
-            result[2] = x[2] + np.sin(x[0])
+            result[2] = x[2] + np.sin(x[0] / 2) / 2
         return result
 
     from meshmode.mesh.processing import map_mesh
diff --git a/meshmode/mesh/refinement/__init__.py b/meshmode/mesh/refinement/__init__.py
index 16639442bdbaf10d6dd2def5c87bd8ff3254c25a..3c39dde0a25ba950ae94cf063a3d87a56c01d9b9 100644
--- a/meshmode/mesh/refinement/__init__.py
+++ b/meshmode/mesh/refinement/__init__.py
@@ -26,8 +26,6 @@ import itertools
 from six.moves import range
 from pytools import RecordWithoutPickling
 
-from meshmode.mesh.generation import make_group_from_vertices
-
 
 class TreeRayNode(object):
     """Describes a ray as a tree, this class represents each node in this tree
@@ -217,31 +215,6 @@ class Refiner(object):
         return self.previous_mesh
 
     def get_current_mesh(self):
-
-        from meshmode.mesh import Mesh
-        #return Mesh(vertices, [grp], nodal_adjacency=self.generate_nodal_adjacency(len(self.last_mesh.groups[group].vertex_indices) \
-        #            + count*3))
-        groups = []
-        grpn = 0
-        for grp in self.last_mesh.groups:
-            groups.append(np.empty([len(grp.vertex_indices),
-                len(self.last_mesh.groups[grpn].vertex_indices[0])], dtype=np.int32))
-            for iel_grp in range(grp.nelements):
-                for i in range(0, len(grp.vertex_indices[iel_grp])):
-                    groups[grpn][iel_grp][i] = grp.vertex_indices[iel_grp][i]
-            grpn += 1
-        grp = []
-
-        for grpn in range(0, len(groups)):
-            grp.append(make_group_from_vertices(self.last_mesh.vertices, groups[grpn], 4))
-        self.last_mesh = Mesh(
-                self.last_mesh.vertices, grp,
-                nodal_adjacency=self.generate_nodal_adjacency(
-                    len(self.last_mesh.groups[0].vertex_indices),
-                    len(self.last_mesh.vertices[0])),
-                vertex_id_dtype=self.last_mesh.vertex_id_dtype,
-                element_id_dtype=self.last_mesh.element_id_dtype)
-
         return self.last_mesh
 
     def get_leaves(self, cur_node):
@@ -590,15 +563,42 @@ class Refiner(object):
             element_mapping = []
             tesselation = None
 
+            # {{{ get midpoint coordinates for vertices
+
+            midpoints_to_find = []
+            resampler = None
+            for iel_grp in range(grp.nelements):
+                if refine_flags[iel_base + iel_grp]:
+                    # simplex: one more vertex than the group dimension
+                    if len(grp.vertex_indices[iel_grp]) == grp.dim + 1:
+                        midpoints_to_find.append(iel_grp)
+                        if resampler is None:
+                            from meshmode.mesh.refinement.resampler import (
+                                SimplexResampler)
+                            resampler = SimplexResampler()
+                            tesselation = self._Tesselation(
+                                self.simplex_result[grp.dim],
+                                self.simplex_node_tuples[grp.dim])
+                    else:
+                        raise NotImplementedError("unimplemented: midpoint finding "
+                                                  "for non simplex elements")
+
+            if midpoints_to_find:
+                midpoints = resampler.get_midpoints(
+                        grp, tesselation, midpoints_to_find)
+                midpoint_order = resampler.get_vertex_pair_to_midpoint_order(grp.dim)
+
+            del midpoints_to_find
+
+            # }}}
+
             for iel_grp in range(grp.nelements):
                 element_mapping.append([iel_grp])
                 if refine_flags[iel_base+iel_grp]:
                     midpoint_vertices = []
                     vertex_indices = grp.vertex_indices[iel_grp]
-                    #if simplex
+                    # if simplex
                     if len(vertex_indices) == grp.dim + 1:
-                        # {{{ Get midpoints for all pairs of vertices
-
                         for i in range(len(vertex_indices)):
                             for j in range(i+1, len(vertex_indices)):
                                 min_index = min(vertex_indices[i], vertex_indices[j])
@@ -617,18 +617,15 @@ class Refiner(object):
                                     vertex_pair2 = (max_index, vertices_index)
                                     self.pair_map[vertex_pair1] = cur_node.left
                                     self.pair_map[vertex_pair2] = cur_node.right
-                                    for k in range(len(self.last_mesh.vertices)):
-                                        vertices[k, vertices_index] = \
-                                        (self.last_mesh.vertices[k, vertex_indices[i]] +
-                                        self.last_mesh.vertices[k, vertex_indices[j]]) / 2.0
+                                    midpoint_idx = midpoint_order[(i, j)]
+                                    vertices[:, vertices_index] = \
+                                            midpoints[iel_grp][:, midpoint_idx]
                                     midpoint_vertices.append(vertices_index)
                                     vertices_index += 1
                                 else:
                                     cur_midpoint = cur_node.midpoint
                                     midpoint_vertices.append(cur_midpoint)
 
-                        # }}}
-
                         #generate new rays
                         cur_dim = len(grp.vertex_indices[0])-1
                         for i in range(len(midpoint_vertices)):
@@ -653,12 +650,12 @@ class Refiner(object):
                             for j in range(len(self.simplex_result[cur_dim][i])):
                                 groups[grpn][iel][j] = \
                                     node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
-                        tesselation = self._Tesselation(
-                            self.simplex_result[cur_dim], self.simplex_node_tuples[cur_dim])
                         nelements_in_grp += len(self.simplex_result[cur_dim])-1
                     #assuming quad otherwise
-                    #else:
+                    else:
                         #quadrilateral
+                        raise NotImplementedError("unimplemented: "
+                                                  "support for quad elements")
 #                        node_tuple_to_coord = {}
 #                        for node_index, node_tuple in enumerate(self.index_to_node_tuple[cur_dim]):
 #                            node_tuple_to_coord[node_tuple] = grp.vertex_indices[iel_grp][node_index]
@@ -699,15 +696,64 @@ class Refiner(object):
         #check_adjacent_elements(groups, new_hanging_vertex_element, nelements_in_grp)
 
         self.hanging_vertex_element = new_hanging_vertex_element
-        grp = []
-        for grpn in range(0, len(groups)):
-            grp.append(make_group_from_vertices(vertices, groups[grpn], 4))
+
+        # {{{ make new groups
+
+        new_mesh_el_groups = []
+
+        for refinement_record, group, prev_group in zip(
+                self.group_refinement_records, groups, self.last_mesh.groups):
+            is_simplex = len(prev_group.vertex_indices[0]) == prev_group.dim + 1
+            ambient_dim = len(prev_group.nodes)
+            nelements = len(group)
+            nunit_nodes = len(prev_group.unit_nodes[0])
+
+            nodes = np.empty(
+                (ambient_dim, nelements, nunit_nodes),
+                dtype=prev_group.nodes.dtype)
+
+            element_mapping = refinement_record.element_mapping
+
+            to_resample = [elem for elem, mapped in enumerate(element_mapping)
+                           if len(mapped) > 1]
+
+            if to_resample:
+                # only simplex groups can be resampled so far
+                if is_simplex:
+                    from meshmode.mesh.refinement.resampler import SimplexResampler
+                    resampler = SimplexResampler()
+                    new_nodes = resampler.get_tesselated_nodes(
+                        prev_group, refinement_record.tesselation, to_resample)
+                else:
+                    raise NotImplementedError(
+                        "unimplemented: node resampling for non simplex elements")
+
+            for elem, mapped_elems in enumerate(element_mapping):
+                if len(mapped_elems) == 1:
+                    # Element maps to a single element: no resampling required,
+                    # so its nodes are copied over from the previous group.
+                    nodes[:, mapped_elems[0]] = prev_group.nodes[:, elem]
+                else:
+                    nodes[:, mapped_elems] = new_nodes[elem]
+
+            if is_simplex:
+                new_mesh_el_groups.append(
+                    type(prev_group)(
+                        order=prev_group.order,
+                        vertex_indices=group,
+                        nodes=nodes,
+                        unit_nodes=prev_group.unit_nodes))
+            else:
+                raise NotImplementedError("unimplemented: support for creating "
+                                          "non simplex element groups")
+
+        # }}}
 
         from meshmode.mesh import Mesh
 
         self.previous_mesh = self.last_mesh
         self.last_mesh = Mesh(
-                vertices, grp,
+                vertices, new_mesh_el_groups,
                 nodal_adjacency=self.generate_nodal_adjacency(
                     totalnelements, nvertices, groups),
                 vertex_id_dtype=self.last_mesh.vertex_id_dtype,
diff --git a/meshmode/mesh/refinement/resampler.py b/meshmode/mesh/refinement/resampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea22b3aa3f7c36b23f823e06646bb27e2723110
--- /dev/null
+++ b/meshmode/mesh/refinement/resampler.py
@@ -0,0 +1,134 @@
+from __future__ import division, absolute_import, print_function
+
+__copyright__ = "Copyright (C) 2016 Matt Wala"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+import numpy as np
+import modepy as mp
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+# {{{ resampling simplex points for refinement
+
+# NOTE: Class internal to refiner: do not make documentation public.
+class SimplexResampler(object):
+    """
+    Resampling of points on simplex elements for refinement.
+
+    Most methods take a ``tesselation`` parameter.
+    The tesselation should follow the format of
+    :func:`meshmode.mesh.tesselate.tesselatetri()` or
+    :func:`meshmode.mesh.tesselate.tesselatetet()`.
+    """
+
+    def get_vertex_pair_to_midpoint_order(self, dim):
+        """
+        Return the numbering of the midpoints, keyed by vertex pair.
+
+        :arg dim: Dimension of the element
+        :return: A :class:`dict` mapping the vertex pair :math:`(v1, v2)` (with
+            :math:`v1 < v2`) to the number of the midpoint in the tesselation
+            ordering (numbered over the midpoints only, so without gaps)
+        """
+        nmidpoints = dim * (dim + 1) // 2
+        return dict(zip(
+            ((i, j) for j in range(dim + 1) for i in range(j)),
+            range(nmidpoints)
+            ))
+
+    def get_midpoints(self, group, tesselation, elements):
+        """
+        Compute the vertex midpoints of the specified elements by resampling.
+
+        :arg group: An instance of :class:`meshmode.mesh.SimplexElementGroup`
+        :arg tesselation: With attributes `ref_vertices`, `children`
+        :arg elements: A list of (group-relative) element numbers
+
+        :return: A :class:`dict` mapping element numbers to midpoint
+            coordinates, with each value in the map having shape
+            ``(ambient_dim, nmidpoints)``. The ordering of the midpoints
+            follows their ordering in the tesselation (see also
+            :meth:`SimplexResampler.get_vertex_pair_to_midpoint_order`)
+        """
+        assert len(group.vertex_indices[0]) == group.dim + 1
+
+        # Midpoints (reference vertices containing a 1), in unit coordinates.
+        midpoints = -1 + np.array([vertex for vertex in
+                tesselation.ref_vertices if 1 in vertex], dtype=float)
+
+        resamp_mat = mp.resampling_matrix(
+            mp.simplex_best_available_basis(group.dim, group.order),
+            midpoints.T,
+            group.unit_nodes)
+
+        resamp_midpoints = np.einsum("mu,deu->edm",
+                                     resamp_mat,
+                                     group.nodes[:, elements])
+
+        return dict(zip(elements, resamp_midpoints))
+
+    def get_tesselated_nodes(self, group, tesselation, elements):
+        """
+        Compute the nodes of the child elements according to the tesselation.
+
+        :arg group: An instance of :class:`meshmode.mesh.SimplexElementGroup`
+        :arg tesselation: With attributes `ref_vertices`, `children`
+        :arg elements: A list of (group-relative) element numbers
+
+        :return: A :class:`dict` mapping element numbers to node
+            coordinates, with each value in the map having shape
+            ``(ambient_dim, nchildren, nunit_nodes)``.
+            The ordering of the child nodes follows the ordering
+            of ``tesselation.children``.
+        """
+        assert len(group.vertex_indices[0]) == group.dim + 1
+
+        from meshmode.mesh.refinement.utils import map_unit_nodes_to_children
+
+        # Unit nodes of all children, concatenated along the node axis.
+        child_unit_nodes = np.hstack(list(
+            map_unit_nodes_to_children(group.unit_nodes, tesselation)))
+
+        resamp_mat = mp.resampling_matrix(
+            mp.simplex_best_available_basis(group.dim, group.order),
+            child_unit_nodes,
+            group.unit_nodes)
+
+        resamp_unit_nodes = np.einsum("cu,deu->edc",
+                                      resamp_mat,
+                                      group.nodes[:, elements])
+
+        ambient_dim = len(group.nodes)
+        nunit_nodes = len(group.unit_nodes[0])
+
+        return dict((elem,
+            resamp_unit_nodes[ielem].reshape(
+                 (ambient_dim, -1, nunit_nodes)))
+            for ielem, elem in enumerate(elements))
+
+# }}}
+
+
+# vim: foldmethod=marker
diff --git a/meshmode/mesh/refinement/utils.py b/meshmode/mesh/refinement/utils.py
index 751142aa6d67ff7717b017ebe83f13a18e216ef3..08b71f8755f4f270756c904d8d5a592e753cb30c 100644
--- a/meshmode/mesh/refinement/utils.py
+++ b/meshmode/mesh/refinement/utils.py
@@ -29,8 +29,42 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+# {{{ map unit nodes to children
+
+
+def map_unit_nodes_to_children(unit_nodes, tesselation):
+    """
+    Given a collection of unit nodes, yield the coordinates of the
+    unit nodes mapped onto each of the children of the reference
+    element (one array per entry of `tesselation.children`).
+
+    The tesselation should follow the format of
+    :func:`meshmode.mesh.tesselate.tesselatetri()` or
+    :func:`meshmode.mesh.tesselate.tesselatetet()`.
+
+    `unit_nodes` should be relative to the unit simplex coordinates in
+    :mod:`modepy`.
+
+    :arg unit_nodes: shaped `(dim, nunit_nodes)`
+    :arg tesselation: With attributes `ref_vertices`, `children`
+    """
+    ref_vertices = np.array(tesselation.ref_vertices, dtype=float)
+
+    assert len(unit_nodes.shape) == 2
+
+    for child_element in tesselation.children:
+        center = np.vstack(ref_vertices[child_element[0]])
+        # Scale by 1/2 since sides in the tesselation have length 2.
+        aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2
+        # (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices.
+        # Hence the translation by +/- 1.
+        yield aff_mat.dot(unit_nodes + 1) + center - 1
+
+# }}}
+
 # {{{ test nodal adjacency against geometry
 
+
 def is_symmetric(relation, debug=False):
     for a, other_list in enumerate(relation):
         for b in other_list:
diff --git a/test/test_refinement.py b/test/test_refinement.py
index bed61eb94c900db0f78b526cfe43ffababe4d825..8098766d446874b08e676d304d5ac73b88e1163a 100644
--- a/test/test_refinement.py
+++ b/test/test_refinement.py
@@ -47,10 +47,10 @@ logger = logging.getLogger(__name__)
 from functools import partial
 
 
-def gen_blob_mesh(h=0.2):
+def gen_blob_mesh(h=0.2, order=1):
     from meshmode.mesh.io import generate_gmsh, FileSource
     return generate_gmsh(
-            FileSource("blob-2d.step"), 2, order=1,
+            FileSource("blob-2d.step"), 2, order=order,
             force_ambient_dim=2,
             other_options=[
                 "-string", "Mesh.CharacteristicLengthMax = %s;" % h]
@@ -152,16 +152,21 @@ def test_refinement(case_name, mesh_gen, flag_gen, num_generations):
     ("warp", 2, [4, 5, 6]),
     ("warp", 3, [4, 5, 6]),
 ])
+@pytest.mark.parametrize("mesh_order", [1, 5])
 @pytest.mark.parametrize("refine_flags", [
     # FIXME: slow
-    # uniform_refine_flags,
+    #uniform_refine_flags,
     partial(random_refine_flags, 0.4)
 ])
 def test_refinement_connection(
-        ctx_getter, group_factory, mesh_name, dim, mesh_pars, refine_flags):
+        ctx_getter, group_factory, mesh_name, dim, mesh_pars, mesh_order,
+        refine_flags, plot_mesh=False):
     from random import seed
     seed(13)
 
+    # Discretization order
+    order = 5
+
     cl_ctx = ctx_getter()
     queue = cl.CommandQueue(cl_ctx)
 
@@ -172,8 +177,6 @@ def test_refinement_connection(
     from pytools.convergence import EOCRecorder
     eoc_rec = EOCRecorder()
 
-    order = 5
-
     def f(x):
         from six.moves import reduce
         return 0.1 * reduce(lambda x, y: x * cl.clmath.sin(5 * y), x)
@@ -185,14 +188,17 @@ def test_refinement_connection(
             assert dim == 1
             h = 1 / mesh_par
             mesh = make_curve_mesh(
-                partial(ellipse, 1), np.linspace(0, 1, mesh_par + 1), order=1)
+                partial(ellipse, 1), np.linspace(0, 1, mesh_par + 1),
+                order=mesh_order)
         elif mesh_name == "blob":
+            if mesh_order == 5:
+                pytest.xfail("https://gitlab.tiker.net/inducer/meshmode/issues/2")
             assert dim == 2
             h = mesh_par
-            mesh = gen_blob_mesh(h)
+            mesh = gen_blob_mesh(h, mesh_order)
         elif mesh_name == "warp":
             from meshmode.mesh.generation import generate_warped_rect_mesh
-            mesh = generate_warped_rect_mesh(dim, order=1, n=mesh_par)
+            mesh = generate_warped_rect_mesh(dim, order=mesh_order, n=mesh_par)
             h = 1/mesh_par
         else:
             raise ValueError("mesh_name not recognized")
@@ -202,7 +208,9 @@ def test_refinement_connection(
         discr = Discretization(cl_ctx, mesh, group_factory(order))
 
         refiner = Refiner(mesh)
-        refiner.refine(refine_flags(mesh))
+        flags = refine_flags(mesh)
+        refiner.refine(flags)
+
         connection = make_refinement_connection(
             refiner, discr, group_factory(order))
         check_connection(connection)
@@ -215,6 +223,18 @@ def test_refinement_connection(
         f_interp = connection(queue, f_coarse).with_queue(queue)
         f_true = f(x_fine).with_queue(queue)
 
+        if plot_mesh:
+            import matplotlib.pyplot as plt
+            x = x.get(queue)
+            err = np.array(np.log10(
+                1e-16 + np.abs((f_interp - f_true).get(queue))), dtype=float)
+            import matplotlib.cm as cm
+            cmap = cm.ScalarMappable(cmap=cm.jet)
+            cmap.set_array(err)
+            plt.scatter(x[0], x[1], c=cmap.to_rgba(err), s=20, cmap=cmap)
+            plt.colorbar(cmap)
+            plt.show()
+
         import numpy.linalg as la
         err = la.norm((f_interp - f_true).get(queue), np.inf)
         eoc_rec.add_data_point(h, err)