diff --git a/examples/plot-connectivity.py b/examples/plot-connectivity.py
index ca6b235cc88ad0e51ad1867570c95211dffbc0bd..83d76b899a92e0f59d866f198322bac2feebbab0 100644
--- a/examples/plot-connectivity.py
+++ b/examples/plot-connectivity.py
@@ -2,7 +2,7 @@ from __future__ import division
 
 import numpy as np  # noqa
 import pyopencl as cl
-
+import random
 import os
 order = 4
 
@@ -54,11 +54,14 @@ def main2():
     mesh = generate_torus(3, 1, order=order)
     from meshmode.mesh.refinement import Refiner
     r = Refiner(mesh)
-    flags = np.zeros(len(mesh.groups[0].vertex_indices))
-    for i in range(0, len(flags)):
-        if i % 2 == 0:
-            flags[i] = 1
-    mesh = r.refine(flags)
+    
+    times = random.randint(1, 1)
+    for time in xrange(times):
+        flags = np.zeros(len(mesh.groups[0].vertex_indices))
+        for i in xrange(0, len(flags)):
+            flags[i] = random.randint(0, 1)
+        mesh = r.refine(flags)
+
     from meshmode.discretization import Discretization
     from meshmode.discretization.poly_element import \
             PolynomialWarpAndBlendGroupFactory
diff --git a/meshmode/mesh/refinement.py b/meshmode/mesh/refinement.py
index 5156fd705fc41b85e9c94f61d12ad9421ebcc83a..2a563cf83591b37b77404df9b995d94e7fd2ee9a 100644
--- a/meshmode/mesh/refinement.py
+++ b/meshmode/mesh/refinement.py
@@ -37,7 +37,21 @@ class _SplitFaceRecord(object):
         integer or None
     """
 class Adj:
-    def __init__(self, vertex=None, elements=None, velements=[None, None]):
+    """One vertex-associated entry of a ray.
+
+    .. attribute:: elements
+
+        A list of numbers of elements adjacent to
+        the edge following :attr:`vertex`
+        along the ray.
+
+    .. attribute:: velements
+
+        A list of numbers of elements adjacent to
+        :attr:`vertex`.
+
+    """
+    def __init__(self, vertex=None, elements=[None, None], velements=[None, None]):
         self.vertex = vertex
         self.elements = elements
         self.velements = velements
@@ -45,6 +59,25 @@ class Adj:
         return 'vertex: ' + str(self.vertex) + ' ' + 'elements: ' + str(self.elements) + ' velements: ' + str(self.velements)
 #map pair of vertices to ray and midpoint
 class PairMap:
+    """Describes a segment of a ray between two vertices.
+
+    .. attribute:: ray
+
+        Index of the ray in the *rays* list.
+
+    .. attribute:: d
+
+        A :class:`bool` denoting direction in the ray,
+        with *True* representing "positive" and *False*
+        representing "negative".
+
+    .. attribute:: midpoint
+
+        Vertex index of the midpoint of this segment.
+
+        *None* if no midpoint has been assigned yet.
+    """
+
     def __init__(self, ray=None, d = True, midpoint=None):
         self.ray = ray
         #direction in ray, True means that second node (bigger index) is after first in ray
@@ -62,17 +95,35 @@ class Refiner(object):
         # around.
         self.last_split_elements = None
 
+        
+        # {{{ initialization
+
+        # a list of dllist instances containing Adj objects
         self.rays = []
+
+        # mapping: (vertex_1, vertex_2) -> PairMap
+        # where vertex_i represents a vertex number
+        #
+        # Assumption: vertex_1 < vertex_2
         self.pair_map = {}
+
         nvertices = len(mesh.vertices[0])
-        #dictionary of ray that a given vertex belongs to, with node in that ray
+        
+        # list of dictionaries, with each entry corresponding to
+        # one vertex.
+        # 
+        # Each dictionary maps
+        #   ray number -> dllist node containing a :class:`Adj`,
+        #                 (part of *rays*)
         self.vertex_to_ray = []
+
         for i in xrange(nvertices):
             self.vertex_to_ray.append({})
         for grp in mesh.groups:
             iel_base = grp.element_nr_base
             for iel_grp in xrange(grp.nelements):
-                for i in range(0, len(grp.vertex_indices[iel_grp])):
+                #use six.moves.range
+                for i in range(len(grp.vertex_indices[iel_grp])):
                     for j in range(i+1, len(grp.vertex_indices[iel_grp])):
                         vertex_pair = (min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]), \
                             max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]))
@@ -80,7 +131,7 @@ class Refiner(object):
                             els = []
                             els.append(iel_base+iel_grp)
                             fr = Adj(min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]), els)
-                            to = Adj(max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]), None)
+                            to = Adj(max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]), [None, None])
                             self.rays.append(dllist([fr, to]))
                             self.pair_map[vertex_pair] = PairMap(len(self.rays) - 1)
                             self.vertex_to_ray[min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][len(self.rays)-1]\
@@ -89,6 +140,9 @@ class Refiner(object):
                                 = self.rays[len(self.rays)-1].nodeat(1)
                         else:
                             self.rays[self.pair_map[vertex_pair].ray].nodeat(0).value.elements.append(iel_base+iel_grp)
+        
+        # }}}
+
         #print mesh.groups[0].vertex_indices
         '''
         self.ray_vertices = np.empty([len(mesh.groups[0].vertex_indices) * 
@@ -250,16 +304,23 @@ class Refiner(object):
                         indices_i += 1
         '''
         import numpy as np
+        # if (isinstance(self.last_mesh.groups[0], SimplexElementGroup) and 
+        #           self.last_mesh.groups[0].dim == 2):
+        print 'refining'
         if(len(self.last_mesh.groups[0].vertex_indices[0]) == 3):
             groups = []
             midpoint_already = {}
             nelements = 0
             nvertices = len(self.last_mesh.vertices[0])
             grpn = 0
-            #create np arrays for groups and vertices
+            totalnelements=0
+
+            # {{{ create new vertices array and each group's vertex_indices
+            
             for grp in self.last_mesh.groups:
                 iel_base = grp.element_nr_base
                 nelements = 0
+                #print grp.nelements
                 for iel_grp in xrange(grp.nelements):
                     nelements += 1
                     if refine_flags[iel_base+iel_grp]:
@@ -273,10 +334,17 @@ class Refiner(object):
                                     midpoint_already[vertex_pair] = True
                 groups.append(np.empty([nelements, len(self.last_mesh.groups[grpn].vertex_indices[grpn])], dtype=np.int32))
                 grpn += 1
+                totalnelements += nelements
+
             vertices = np.empty([len(self.last_mesh.vertices), nvertices])
-            #assign original vertices and elements to vertices and groups
+
+            # }}}
+
+            # {{{ bring over vertex_indices and vertices info from previous generation
+
             for i in range(0, len(self.last_mesh.vertices)):
                 for j in range(0, len(self.last_mesh.vertices[i])):
+                    # always use v[i,j]
                     vertices[i][j] = self.last_mesh.vertices[i][j]
             grpn = 0
             for grp in self.last_mesh.groups:
@@ -285,16 +353,33 @@ class Refiner(object):
                         groups[grpn][iel_grp][i] = grp.vertex_indices[iel_grp][i]
                 grpn += 1
             grpn = 0
+
+            # }}}
+
             vertices_idx = len(self.last_mesh.vertices[0])
             for grp in self.last_mesh.groups:
                 iel_base = grp.element_nr_base
                 indices_idx = len(grp.vertex_indices)
+                
+                # np.where
                 for iel_grp in xrange(grp.nelements):
                     if refine_flags[iel_base+iel_grp]:
+
+                        # {{{ split element
+
+                        # {{{ go through vertex pairs in element
+
+                        # stores indices of all midpoints for this element
+                        # (in order of vertex pairs in elements)
+                        verts = []
+
                         for i in range(0, len(grp.vertex_indices[iel_grp])):
                             for j in range(i+1, len(grp.vertex_indices[iel_grp])):
                                 vertex_pair = (min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]), \
                                 max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]))
+                                
+                                # {{{ create midpoint if it doesn't exist already
+
                                 if self.pair_map[vertex_pair].midpoint is None:
                                     self.pair_map[vertex_pair].midpoint = vertices_idx
                                     vertex_pair1 = (min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]), vertices_idx)
@@ -304,26 +389,133 @@ class Refiner(object):
                                     self.pair_map[vertex_pair2] = PairMap(self.pair_map[vertex_pair].ray, not self.pair_map[vertex_pair].d, \
                                         None)
                                     self.vertex_to_ray.append({})
-
+                                    
+                                    # try and collapse the two branches by setting up variables
+                                    # ahead of time
                                     if self.pair_map[vertex_pair].d:
-                                        velements = self.vertex_to_ray[min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])[self.pair_map[vertex_pair].ray]
-                                        print velements
-                                        self.rays[self.pair_map[vertex_pair].ray].insert(Adj(vertices_idx, None), \
+                                        velements = self.vertex_to_ray[min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray].value.elements
+                                        self.rays[self.pair_map[vertex_pair].ray].insert(Adj(vertices_idx, [None, None], velements), \
                                             self.vertex_to_ray[max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray])
                                         self.vertex_to_ray[vertices_idx][self.pair_map[vertex_pair].ray] = \
                                             self.vertex_to_ray[max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray].prev
                                     else:
-                                        velements = self.vertex_to_ray[max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray]
-                                        self.rays[self.pair_map[vertex_pair].ray].insert(Adj(vertices_idx, None), \
+                                        velements = self.vertex_to_ray[max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray].value.elements
+                                        self.rays[self.pair_map[vertex_pair].ray].insert(Adj(vertices_idx, [None, None], velements), \
                                             self.vertex_to_ray[min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray])
                                         self.vertex_to_ray[vertices_idx][self.pair_map[vertex_pair].ray] = \
                                             self.vertex_to_ray[min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray].prev
+                                    
+                                    # compute location of midpoint
                                     for k in range(0, 3):
                                         vertices[k][vertices_idx] = (self.last_mesh.vertices[k][grp.vertex_indices[iel_grp][i]]
                                             + self.last_mesh.vertices[k][grp.vertex_indices[iel_grp][j]]) / 2.0
+                                    
+                                    verts.append(vertices_idx)
                                     vertices_idx += 1
-                        #for i in range(0, 
+                                else:
+                                    verts.append(self.pair_map[vertex_pair].midpoint)
+
+                                # }}}
+                        
+                        # }}}
 
+                        # {{{ fix connectivity
+
+                        # new elements will be nels+0 .. nels+2
+                        # (While those don't exist yet, we generate connectivity for them
+                        # anyhow.)
+
+                        ind = 0
+
+                        # vertex_pairs = [(i,j) for i in range(3) for j in range(i+1, 3)]
+                        for i in range(0, len(grp.vertex_indices[iel_grp])):
+                            for j in range(i+1, len(grp.vertex_indices[iel_grp])):
+                                vertex_pair = (min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]),
+                                max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j]))
+                                if self.pair_map[vertex_pair].d:
+                                    start_node = self.vertex_to_ray[min(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray]
+                                else:
+                                    start_node = self.vertex_to_ray[max(grp.vertex_indices[iel_grp][i], grp.vertex_indices[iel_grp][j])][self.pair_map[vertex_pair].ray]
+                                end_node = self.vertex_to_ray[self.pair_map[vertex_pair].midpoint][self.pair_map[vertex_pair].ray]
+
+                                # hop along ray from start node to end node
+                                while start_node != end_node:
+                                    # start_node.value.elements.index(iel_base+iel_grp)
+                                    if start_node.value.elements[0] == iel_base+iel_grp:
+                                        start_node.value.elements[0] = iel_base+indices_idx+ind
+                                    elif start_node.value.elements[1] == iel_base+iel_grp:
+                                        start_node.value.elements[1] = iel_base+indices_idx+ind
+                                    else:
+                                        assert False
+                                    start_node = start_node.next
+                                ind += 1
+                        ind = 0
+
+                        # }}}
+
+                        #generate new rays
+                        from llist import dllist, dllistnode
+
+                        for i in range(0, len(verts)):
+                            for j in range(i+1, len(verts)):
+                                vertex_pair = (min(verts[i], verts[j]), max(verts[i], verts[j]))
+                                els = []
+                                els.append(iel_base+iel_grp)
+                                els.append(iel_base+indices_idx+ind)
+                                
+                                fr = Adj(min(verts[i], verts[j]), els)
+
+                                # We're creating a new ray, and this is the end node
+                                # of it.
+                                to = Adj(max(verts[i], verts[j]), [None, None])
+
+                                self.rays.append(dllist([fr, to]))
+                                self.pair_map[vertex_pair] = PairMap(len(self.rays) - 1)
+                                self.vertex_to_ray[min(verts[i], verts[j])][len(self.rays)-1]\
+                                    = self.rays[len(self.rays)-1].nodeat(0)
+                                self.vertex_to_ray[max(verts[i], verts[j])][len(self.rays)-1]\
+                                    = self.rays[len(self.rays)-1].nodeat(1)
+
+                        #generate actual elements
+                        #middle element
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]), \
+                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]))
+                        groups[grpn][iel_grp][0] = self.pair_map[vertex_pair].midpoint
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]), \
+                        max(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]))
+                        groups[grpn][iel_grp][1] = self.pair_map[vertex_pair].midpoint
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]), \
+                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]))
+                        groups[grpn][iel_grp][2] = self.pair_map[vertex_pair].midpoint
+                        #element 0
+                        groups[grpn][indices_idx][0] = grp.vertex_indices[iel_grp][0]
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]), \
+                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]))
+                        groups[grpn][indices_idx][1] = self.pair_map[vertex_pair].midpoint
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]), \
+                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]))
+                        groups[grpn][indices_idx][2] = self.pair_map[vertex_pair].midpoint
+                        indices_idx += 1
+                        #element 1
+                        groups[grpn][indices_idx][0] = grp.vertex_indices[iel_grp][1]
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][0]), \
+                        max(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][0]))
+                        groups[grpn][indices_idx][1] = self.pair_map[vertex_pair].midpoint
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]), \
+                        max(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]))
+                        groups[grpn][indices_idx][2] = self.pair_map[vertex_pair].midpoint
+                        indices_idx += 1
+                        #element 2
+                        groups[grpn][indices_idx][0] = grp.vertex_indices[iel_grp][2]
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][0]), \
+                        max(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][0]))
+                        groups[grpn][indices_idx][1] = self.pair_map[vertex_pair].midpoint
+                        vertex_pair = (min(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][1]), \
+                        max(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][1]))
+                        groups[grpn][indices_idx][2] = self.pair_map[vertex_pair].midpoint
+                        indices_idx += 1
+
+                        # }}}
 
                 grpn += 1
 
@@ -331,12 +523,18 @@ class Refiner(object):
         #print vertices
         #print vertex_indices
         from meshmode.mesh.generation import make_group_from_vertices
-        grp = make_group_from_vertices(vertices, vertex_indices, 4)
+        #grp = make_group_from_vertices(vertices, vertex_indices, 4)
+        grp = []
+        grpn = 0
+        for grpn in range(0, len(groups)):
+            grp.append(make_group_from_vertices(vertices, groups[grpn], 4))
+
         from meshmode.mesh import Mesh
         #return Mesh(vertices, [grp], element_connectivity=self.generate_connectivity(len(self.last_mesh.groups[group].vertex_indices) \
         #            + count*3))
         
-        return Mesh(vertices, [grp], element_connectivity=None)
+        self.last_mesh = Mesh(vertices, grp, element_connectivity=self.generate_connectivity(totalnelements, nvertices))
+        return self.last_mesh
         split_faces = {}
 
         ibase = self.get_refine_base_index()
@@ -347,8 +545,10 @@ class Refiner(object):
 
 
 
-    def generate_connectivity(self, nelements):
-        _, nvertices = self.last_mesh.vertices.shape
+    def generate_connectivity(self, nelements, nvertices):
+        # medium-term FIXME: make this an incremental update
+        # rather than build-from-scratch
+
         vertex_to_element = [[] for i in xrange(nvertices)]
 
         for grp in self.last_mesh.groups:
@@ -357,19 +557,34 @@ class Refiner(object):
                 for ivertex in grp.vertex_indices[iel_grp]:
                     vertex_to_element[ivertex].append(iel_base + iel_grp)
 
-        element_to_element = [set() for i in xrange(self.last_mesh.nelements)]
+        element_to_element = [set() for i in xrange(nelements)]
         for grp in self.last_mesh.groups:
             iel_base = grp.element_nr_base
             for iel_grp in xrange(grp.nelements):
                 for ivertex in grp.vertex_indices[iel_grp]:
                     element_to_element[iel_base + iel_grp].update(
                             vertex_to_element[ivertex])
+        
         #print self.ray_elements
+        for ray in self.rays:
+            curnode = ray.first
+            while curnode is not None:
+                if curnode.value.elements[0] is not None:
+                    element_to_element[curnode.value.elements[0]].update(curnode.value.elements)
+                if curnode.value.elements[1] is not None:
+                    element_to_element[curnode.value.elements[1]].update(curnode.value.elements)
+                if curnode.value.velements[0] is not None:
+                    element_to_element[curnode.value.velements[0]].update(curnode.value.velements)
+                if curnode.value.velements[1] is not None:
+                    element_to_element[curnode.value.velements[1]].update(curnode.value.velements)
+                curnode = curnode.next
+        '''
         for i in self.ray_elements:
             for j in i:
                 #print j[0], j[1]
                 element_to_element[j[0]].update(j)
                 element_to_element[j[1]].update(j)
+        '''
         #print element_to_element
         lengths = [len(el_list) for el_list in element_to_element]
         neighbors_starts = np.cumsum(