diff --git a/meshmode/mesh/refinement.py b/meshmode/mesh/refinement.py
index eb22d5e79adb962c33e106121750a6df9e7d8585..7ea840c11fa778eb8a84e22511246c622179c583 100644
--- a/meshmode/mesh/refinement.py
+++ b/meshmode/mesh/refinement.py
@@ -22,80 +22,31 @@ THE SOFTWARE.
 
 import numpy as np
 
-class VertexRay:
-    def __init__(self, ray, pos):
-        self.ray = ray
-        self.pos = pos
-
-class _SplitFaceRecord(object):
-    """
-    .. attribute:: neighboring_elements
-    .. attribute:: new_vertex_nr
-        integer or None
-    """
-class Adj:
-    """One vertex-associated entry of a ray.
-    .. attribute:: elements
-        A list of numbers of elements adjacent to
-        the edge following :attr:`vertex`
-        along the ray.
-    """
-    def __init__(self, vertex=None, elements=[-1, -1]):
-        self.vertex = vertex
-        self.elements = elements
-    def __str__(self):
-        return 'vertex: ' + str(self.vertex) + ' ' + 'elements: ' +\
-        str(self.elements)
-#map pair of vertices to ray and midpoint
-class PairMap:
-    """Describes a segment of a ray between two vertices.
-    .. attribute:: ray
-        Index of the ray in the *rays* list.
-    .. attribute:: d
+class TreeRayNode:
+    """Describes a ray as a tree, this class represents each node in this tree
+    .. attribute:: left
+        Left child.
+        *None* if ray segment hasn't been split yet.
+    .. attribute:: right
+        Right child.
+        *None* if ray segment hasn't been split yet.
+    .. attribute:: midpoint
+        Vertex index of midpoint of this ray segment.
+        *None* if no midpoint has been assigned yet.
+    .. attribute:: direction
         A :class:`bool` denoting direction in the ray,
         with *True* representing "positive" and *False*
         representing "negative".
-    .. attribute:: midpoint
-        Vertex index of the midpoint of this segment.
-        *None* if no midpoint has been assigned yet.
+    .. attribute:: adjacent_elements
+        List containing elements indices of elements adjacent
+        to this ray segment.
     """
-
-    def __init__(self, ray=None, d = True, midpoint=None):
-        self.ray = ray
-        #direction in ray, True means that second node (bigger index
-        #) is after first in ray
-        self.d = d
-        self.midpoint = midpoint
-    def __str__(self):
-        return 'ray: ' + str(self.ray) + ' d: ' + str(self.d) + \
-                ' midpoint: ' + str(self.midpoint)
-
-'''
-class ElementRefinementTemplate:
-    def __init__(self, dim):
-        from meshmode.mesh.tesselate import tesselatetri
-        self.node_tuples, self.refined_elements = tesselatetri()
-        print self.node_tuples
-        #dicionary that maps a pair of vertices (node arrays) to a special element
-        self.special_elements = {}
-        self.vertices_to_midpoint = {}
-                
-        for i in self.refined_elements:
-            has_two = False
-            has_one = False 
-            for j in i:
-                for k in self.node_tuples[j]:
-                    if k == 1:
-                        has_one = True
-                    if k == 2:
-                        has_two = True
-            #found special element
-            if has_one and not has_two:
-                for j in i:
-                    has_two = False
-                    has_one = False
-                    special_elements 
-'''
+    def __init__(self, direction = True, adjacent_elements = []):
+        self.left = None
+        self.right = None
+        self.midpoint = None
+        self.direction = direction
+        self.adjacent_elements = adjacent_elements
 
 class Refiner(object):
     def __init__(self, mesh):
@@ -103,18 +54,10 @@ class Refiner(object):
         from meshmode.mesh.tesselate  import tesselatetet
         self.tri_node_tuples, self.tri_result = tesselatetet()
         self.last_mesh = mesh
-        # Indices in last_mesh that were split in the last round of
-        # refinement. Only these elements may be split this time
-        # around.
-        self.last_split_elements = None
-
         
         # {{{ initialization
 
-        # a list of dllist instances containing Adj objects
-        self.rays = []
-
-        # mapping: (vertex_1, vertex_2) -> PairMap
+        # mapping: (vertex_1, vertex_2) -> TreeRayNode 
         # where vertex_i represents a vertex number
         #
         # Assumption: vertex_1 < vertex_2
@@ -122,32 +65,17 @@ class Refiner(object):
 
         nvertices = len(mesh.vertices[0])
         
-        # list of dictionaries, with each entry corresponding to
-        # one vertex.
-        # 
-        # Each dictionary maps
-        #   ray number -> dllist node containing a :class:`Adj`,
-        #                 (part of *rays*)
-        self.vertex_to_ray = []
-
-        #np array containing element whose edge lies on corresponding vertex
-        import six
+        #array containing element whose edge lies on corresponding vertex
         self.hanging_vertex_element = []
+        import six
         for i in six.moves.range(nvertices):
             self.hanging_vertex_element.append([])
-#        self.hanging_vertex_element = np.empty([nvertices], dtype=np.int32)
-#        self.hanging_vertex_element.fill(-1)
 
         import six
-        for i in six.moves.range(nvertices):
-            self.vertex_to_ray.append({})
         for grp in mesh.groups:
             iel_base = grp.element_nr_base
             for iel_grp in six.moves.range(grp.nelements):
-                #use six.moves.range
-
                 vert_indices = grp.vertex_indices[iel_grp]
-
                 for i in six.moves.range(len(vert_indices)):
                     for j in six.moves.range(i+1, len(vert_indices)):
                        
@@ -158,18 +86,10 @@ class Refiner(object):
 
                         vertex_pair = (mn_idx, mx_idx)
                         if vertex_pair not in self.pair_map:
-                            els = []
-                            els.append(iel_base+iel_grp)
-                            fr = Adj(mn_idx, els)
-                            to = Adj(mx_idx, [])
-                            self.rays.append(dllist([fr, to]))
-                            self.pair_map[vertex_pair] = PairMap(len(self.rays) - 1)
-                            self.vertex_to_ray[mn_idx][len(self.rays)-1]\
-                                = self.rays[len(self.rays)-1].nodeat(0)
-                            self.vertex_to_ray[mx_idx][len(self.rays)-1]\
-                                = self.rays[len(self.rays)-1].nodeat(1)
-                        else:
-                            self.rays[self.pair_map[vertex_pair].ray].nodeat(0).value.elements.append(iel_base+iel_grp)
+                            self.pair_map[vertex_pair] = TreeRayNode()
+
+                        (self.pair_map[vertex_pair].
+                            adjacent_elements.append(iel_base+iel_grp))
         # }}}
 
         #generate reference tuples
@@ -192,6 +112,9 @@ class Refiner(object):
                     cur = int((i[k] + j[k]) / 2)
                     self.index_to_midpoint_tuple[curind] = self.index_to_midpoint_tuple[curind] + (cur,)
                 curind += 1
+        print "LKJAFLKJASFLJASF"
+        print self.index_to_node_tuple
+        print self.index_to_midpoint_tuple
         #print mesh.groups[0].vertex_indices
         '''
         self.ray_vertices = np.empty([len(mesh.groups[0].vertex_indices) * 
@@ -272,663 +195,27 @@ class Refiner(object):
 
     #refine_flag tells you which elements to split as a numpy array of bools
     def refine(self, refine_flags):
-        """
-        :return: a refined mesh
-        """
-
-        # multi-level mapping:
-        # {
-        #   dimension of intersecting entity (0=vertex,1=edge,2=face,...)
-        #   :
-        #   { frozenset of end vertices : _SplitFaceRecord }
-        # }
-        '''
-        import numpy as np
-        count = 0
-        for i in refine_flags:
-            if i:
-                count += 1
-        
-        le = len(self.last_mesh.vertices[0])
-        vertices = np.empty([len(self.last_mesh.vertices), len(self.last_mesh.groups[0].vertex_indices[0])
-            * count + le])
-        vertex_indices = np.empty([len(self.last_mesh.groups[0].vertex_indices)
-            + count*3, 
-            len(self.last_mesh.groups[0].vertex_indices[0])], dtype=np.int32)
-        indices_it = len(self.last_mesh.groups[0].vertex_indices)        
-        for i in range(0, len(self.last_mesh.vertices)):
-            for j in range(0, len(self.last_mesh.vertices[i])):
-                vertices[i][j] = self.last_mesh.vertices[i][j]
-        for i in range(0, len(self.last_mesh.groups[0].vertex_indices)):
-            for j in range(0, len(self.last_mesh.groups[0].vertex_indices[i])):
-                vertex_indices[i][j] = self.last_mesh.groups[0].vertex_indices[i][j]
-        
-        import itertools
-        for i in range(0, len(refine_flags)):
-            if refine_flags[i]:
-                for subset in itertools.combinations(self.last_mesh.groups[0].vertex_indices[i], 
-                    len(self.last_mesh.groups[0].vertex_indices[i]) - 1):
-                    for j in range(0, len(self.last_mesh.vertices)):
-                        avg = 0
-                        for k in subset:
-                            avg += self.last_mesh.vertices[j][k]
-                        avg /= len(self.last_mesh.vertices)
-                        self.last_mesh.vertices[j][le] = avg
-                        le++
-                vertex_indices[indices_it][0] = self.last_mesh.groups[0].vertex_indices[i][0]
-        '''
-        '''
-        import numpy as np
-        count = 0
-        for i in refine_flags:
-            if i:
-                count += 1
-        #print count
-        #print self.last_mesh.vertices
-        #print vertices
-        if(len(self.last_mesh.groups[0].vertex_indices[0]) == 3):
-            for group in range(0, len(self.last_mesh.groups)):
-                le = len(self.last_mesh.vertices[0])
-                vertices = np.empty([len(self.last_mesh.vertices), 
-                    len(self.last_mesh.groups[group].vertex_indices[0])
-                    * count + le])
-                vertex_indices = np.empty([len(self.last_mesh.groups[group].vertex_indices)
-                    + count*3, 
-                    len(self.last_mesh.groups[group].vertex_indices[0])], dtype=np.int32)
-                indices_i = 0        
-                for i in range(0, len(self.last_mesh.vertices)):
-                    for j in range(0, len(self.last_mesh.vertices[i])):
-                        vertices[i][j] = self.last_mesh.vertices[i][j]
-                #for i in range(0, len(self.last_mesh.groups[group].vertex_indices)):
-                    #for j in range(0, len(self.last_mesh.groups[group].vertex_indices[i])):
-                        #vertex_indices[i][j] = self.last_mesh.groups[group].vertex_indices[i][j]
-                for fl in range(0, len(refine_flags)):
-                    if(refine_flags[fl]):
-                        i = self.last_mesh.groups[group].vertex_indices[fl]
-                        for j in range(0, len(i)):
-                            for k in range(j + 1, len(i)):
-                                for l in range(0, 3):
-                                    #print self.last_mesh.vertices[l][i[j]], ' ', self.last_mesh.vertices[l][i[k]], '=', ((self.last_mesh.vertices[l][i[j]] + self.last_mesh.vertices[l][i[k]]) / 2)
-                                    vertices[l][le]=((self.last_mesh.vertices[l][i[j]] + self.last_mesh.vertices[l][i[k]]) / 2)
-                                le += 1
-                        vertex_indices[indices_i][0] = i[0]
-                        vertex_indices[indices_i][1] = le-3
-                        vertex_indices[indices_i][2] = le-2
-                        indices_i += 1
-                        vertex_indices[indices_i][0] = i[1]
-                        vertex_indices[indices_i][1] = le-1
-                        vertex_indices[indices_i][2] = le-3
-                        indices_i += 1
-                        vertex_indices[indices_i][0] = i[2]
-                        vertex_indices[indices_i][1] = le-2
-                        vertex_indices[indices_i][2] = le-1
-                        indices_i += 1
-                        vertex_indices[indices_i][0] = le-3
-                        vertex_indices[indices_i][1] = le-2
-                        vertex_indices[indices_i][2] = le-1
-                        indices_i += 1
-                    else:
-                        for j in range(0, len(self.last_mesh.groups[group].vertex_indices[fl])):
-                            vertex_indices[indices_i][j] = self.last_mesh.groups[group].vertex_indices[fl][j]
-                        indices_i += 1
-        '''
-        import numpy as np
         import six
-        # if (isinstance(self.last_mesh.groups[0], SimplexElementGroup) and 
-        #           self.last_mesh.groups[0].dim == 2):
-        print 'refining'
-        if(len(self.last_mesh.groups[0].vertex_indices[0]) == 3 or len(self.last_mesh.groups[0].vertex_indices[0]) == 4):
-            groups = []
-            midpoint_already = {}
-            nelements = 0
-            nvertices = len(self.last_mesh.vertices[0])
-            grpn = 0
-            totalnelements=0
-
-            # {{{ create new vertices array and each group's vertex_indices
-            for grp in self.last_mesh.groups:
-                iel_base = grp.element_nr_base
-                nelements = 0
-                #print grp.nelements
-                for iel_grp in six.moves.range(grp.nelements):
-                    nelements += 1
-                    vert_indices = grp.vertex_indices[iel_grp]
-                    if refine_flags[iel_base+iel_grp]:
-                        nelements += len(self.tri_result) - 1
-                        for i in six.moves.range(0, len(vert_indices)):
-                            for j in six.moves.range(i+1, len(vert_indices)):
-                                mn_idx = min(vert_indices[i], vert_indices[j])
-                                mx_idx = max(vert_indices[i], vert_indices[j])
-                                vertex_pair = (mn_idx, mx_idx)
-                                if vertex_pair not in midpoint_already and \
-                                    self.pair_map[vertex_pair].midpoint is None:
-                                    nvertices += 1
-                                    midpoint_already[vertex_pair] = True
-                groups.append(np.empty([nelements,
-                    len(grp.vertex_indices[0])], dtype=np.int32))
-                grpn += 1
-                totalnelements += nelements
-
-            vertices = np.empty([len(self.last_mesh.vertices), nvertices])
-            
-            #create new hanging_vertex_element array
-#            new_hanging_vertex_element = np.empty([nvertices], dtype=np.int32)
-#            new_hanging_vertex_element.fill(-1) 
-            new_hanging_vertex_element = []
-            for i in range(0, nvertices):
-                new_hanging_vertex_element.append([])
-            # }}}
-
-            # {{{ bring over hanging_vertex_element, vertex_indices and vertices info from previous generation
-
-            for i in six.moves.range(0, len(self.last_mesh.vertices)):
-                for j in six.moves.range(0, len(self.last_mesh.vertices[i])):
-                    # always use v[i,j]
-                    vertices[i,j] = self.last_mesh.vertices[i,j]
-                    import copy
-                    if i == 0:
-                        new_hanging_vertex_element[j] = copy.deepcopy(self.hanging_vertex_element[j])
-            grpn = 0
-            for grp in self.last_mesh.groups:
-                for iel_grp in six.moves.range(grp.nelements):
-                    for i in six.moves.range(0, len(grp.vertex_indices[iel_grp])):
-                        groups[grpn][iel_grp][i] = grp.vertex_indices[iel_grp][i]
-                grpn += 1
-            grpn = 0
-
-            # }}}
-
-            vertices_idx = len(self.last_mesh.vertices[0])
-            for grp in self.last_mesh.groups:
-                iel_base = grp.element_nr_base
-                nelements_in_grp = len(grp.vertex_indices)
-                
-                # np.where
-                for iel_grp in six.moves.range(grp.nelements):
-                    #print iel_base+iel_grp, '/', grp.nelements
-                    if refine_flags[iel_base+iel_grp]:
-
-                        # {{{ split element
-
-                        # {{{ go through vertex pairs in element
-
-                        # stores indices of all midpoints for this element
-                        # (in order of vertex pairs in elements)
-                        verts = []
-                        midpoint_tuples = [] 
-                        verts_elements = []
-                        vert_indices = grp.vertex_indices[iel_grp]
-                        #print vert_indices
-                        for i in six.moves.range(0, len(vert_indices)):
-                            for j in six.moves.range(i+1, len(vert_indices)):
-                                verts_elements.append([])
-                                midpoint_tuples.append(tuple([(item1 + item2) / 2 for item1, item2 in zip(self.index_to_node_tuple[i],
-                                    self.index_to_node_tuple[j])]))
-                                if vert_indices[i] < vert_indices[j]:
-                                    mn_idx = vert_indices[i]
-                                    mx_idx = vert_indices[j] 
-                                    imn_idx = i
-                                    imx_idx = j
-                                else:
-                                    mn_idx = vert_indices[j]
-                                    mx_idx = vert_indices[i]
-                                    imn_idx = j
-                                    imx_idx = i
-
-                                vertex_pair = (mn_idx, mx_idx)
-                                
-                                # {{{ create midpoint if it doesn't exist already
-
-                                cur_pmap = self.pair_map[vertex_pair]
-                                if cur_pmap.midpoint is None:
-                                    self.pair_map[vertex_pair].midpoint = vertices_idx
-                                    vertex_pair1 = (mn_idx, vertices_idx)
-                                    vertex_pair2 = (mx_idx, vertices_idx)
-                                    self.pair_map[vertex_pair1] =\
-                                        PairMap(cur_pmap.ray, cur_pmap.d, None)
-                                    self.pair_map[vertex_pair2] =\
-                                        PairMap(cur_pmap.ray, not cur_pmap.d, None)
-                                    self.vertex_to_ray.append({})
-                                    
-                                    # FIXME: Check where the new Adj.elements
-                                    # gets populated.
-
-                                    # try and collapse the two branches by setting up variables
-                                    # ahead of time
-                                    import copy
-                                    if self.pair_map[vertex_pair].d:
-                                        elements = self.vertex_to_ray[mn_idx][cur_pmap.ray].value.elements
-                                        self.rays[cur_pmap.ray].insert(Adj(vertices_idx, copy.deepcopy(elements)),
-                                            self.vertex_to_ray[mx_idx][cur_pmap.ray])
-
-                                        #stupid bug: don't use i when already in use
-                                        for el in elements:
-                                            if el != (iel_base + iel_grp):
-                                                verts_elements[len(verts_elements)-1].append(el)
-                                                new_hanging_vertex_element[vertices_idx].append(el)
-                                        self.vertex_to_ray[vertices_idx][cur_pmap.ray] = \
-                                            self.vertex_to_ray[mx_idx][cur_pmap.ray].prev
-                                    else:
-                                        elements = self.vertex_to_ray[mx_idx][cur_pmap.ray].value.elements
-                                        self.rays[cur_pmap.ray].insert(Adj(vertices_idx, copy.deepcopy(elements)),
-                                            self.vertex_to_ray[mn_idx][cur_pmap.ray])
-                                        for el in elements:
-                                            if el != (iel_base + iel_grp):
-                                                verts_elements[len(verts_elements)-1].append(el)
-                                                new_hanging_vertex_element[vertices_idx].append(el)
-
-                                        self.vertex_to_ray[vertices_idx][cur_pmap.ray] = \
-                                            self.vertex_to_ray[mn_idx][cur_pmap.ray].prev
-                                    #print len(vert_indices)
-                                    # compute location of midpoint
-                                    for k in six.moves.range(len(self.last_mesh.vertices)):
-                                        vertices[k,vertices_idx] =\
-                                            (self.last_mesh.vertices[k,vert_indices[i]] +
-                                                    self.last_mesh.vertices[k,vert_indices[j]]) / 2.0
-                                    
-                                    verts.append(vertices_idx)
-                                    vertices_idx += 1
-                                else:
-                                    cur_midpoint = self.pair_map[vertex_pair].midpoint
-                                    elements = self.vertex_to_ray[cur_midpoint][cur_pmap.ray].prev.value.elements
-                                    for el in elements:
-                                        if el != (iel_base + iel_grp) and el not in verts_elements[len(verts_elements)-1]:
-                                            verts_elements[len(verts_elements)-1].append(el)
-                                    elements = self.vertex_to_ray[cur_midpoint][cur_pmap.ray].value.elements
-                                    for el in elements:
-                                        if el != (iel_base + iel_grp) and el not in verts_elements[len(verts_elements)-1]:
-                                            verts_elements[len(verts_elements)-1].append(el)
-
-                                    for el in new_hanging_vertex_element[cur_midpoint]:
-                                        if el != (iel_base + iel_grp) and el not in verts_elements[len(verts_elements)-1]:
-                                            verts_elements[len(verts_elements)-1].append(el)
-                                    verts.append(cur_midpoint)
-                                    #new_hanging_vertex_element[cur_midpoint] = []
-                                #print new_hanging_vertex_element
-                                # }}}
-                        
-                        # }}}
-
-                        # {{{ fix connectivity
-
-                        # new elements will be nels+0 .. nels+2 ...
-                        # (While those don't exist yet, we generate connectivity for them
-                        # anyhow.)
-
-                        unique_vertex_pairs = [
-                                (i,j) for i in range(len(vert_indices)) for j in range(i+1, len(vert_indices))]
-                        midpoint_ind = 0
-                        for i, j in unique_vertex_pairs:
-                            mn_idx = min(vert_indices[i], vert_indices[j]) 
-                            mx_idx = max(vert_indices[i], vert_indices[j])
-                            element_indices_1 = []
-                            element_indices_2 = []
-                            for k_ind, k in enumerate(self.tri_result):
-                                ituple_ind = self.tri_node_tuples.index(self.index_to_node_tuple[i])
-                                jtuple_ind = self.tri_node_tuples.index(self.index_to_node_tuple[j])
-                                midpointtuple_ind = self.tri_node_tuples.index(self.index_to_midpoint_tuple[midpoint_ind])
-                                if ituple_ind in k and\
-                                    midpointtuple_ind in k:
-                                        element_indices_1.append(k_ind)
-                                if jtuple_ind in k and\
-                                    midpointtuple_ind in k:
-                                        element_indices_2.append(k_ind)
-                            midpoint_ind += 1
-                            #print "ELEMENTIDX1: ", element_indices_1
-                            #print "ELEMENTIDX2: ", element_indices_2
-                            if mn_idx == vert_indices[i]:
-                                min_element_index = element_indices_1
-                                max_element_index = element_indices_2
-                            else:
-                                min_element_index = element_indices_2
-                                max_element_index = element_indices_1
-                            #print "ELEMENTIDX: ", min_element_index, max_element_index
-                            vertex_pair = (mn_idx, mx_idx)
-                            cur_pmap = self.pair_map[vertex_pair]
-                            if cur_pmap.d:
-                                start_node =\
-                                self.vertex_to_ray[mn_idx][cur_pmap.ray]
-                                end_node = self.vertex_to_ray[mx_idx][cur_pmap.ray]
-                                first_element_index = min_element_index
-                                second_element_index = max_element_index
-                            else:
-                                start_node =\
-                                self.vertex_to_ray[mx_idx][cur_pmap.ray]
-                                end_node = self.vertex_to_ray[mn_idx][cur_pmap.ray]
-                                first_element_index = max_element_index
-                                second_element_index = min_element_index
-                            midpoint_node=\
-                            self.vertex_to_ray[cur_pmap.midpoint][cur_pmap.ray]
-                            # hop along ray from start node to midpoint node
-                            #print "Nodes: ", start_node.value, midpoint_node.value
-                            while start_node != midpoint_node:
-                                # replace old (big) element with new
-                                # (refined) element
-                                print "BEFOREEEalsdfjlkjasdfl:", mn_idx, mx_idx, start_node.value
-                                node_els = start_node.value.elements
-                                #print "OLD NODE ELS: ", node_els
-                                #print node_els
-                                node_els.remove(iel_base+iel_grp)
-                                for k in first_element_index:
-                                    if k == 0:
-                                        node_els.append(iel_base+iel_grp)
-                                    else:
-                                        node_els.append(iel_base+nelements_in_grp+k - 1)
-                                print "AFTERLKjlajsdfljkadf:", mn_idx, mx_idx, start_node.value
-                                #print "NEW_NODE_ELS: ", node_els
-                                #node_els[iold_el] = iel_base+nelements_in_grp+first_element_index
-                                #print "HANGING: ", new_hanging_vertex_element[start_node.value.vertex]
-                                if new_hanging_vertex_element[start_node.value.vertex] and \
-                                    new_hanging_vertex_element[start_node.value.vertex].count(
-                                        iel_base+iel_grp):
-                                        '''
-                                        to_replace_index = new_hanging_vertex_element[start_node.value.vertex].\
-                                                index(iel_base+iel_grp)
-                                        new_hanging_vertex_element[start_node.value.vertex][to_replace_index] =\
-                                                iel_base+nelements_in_grp+first_element_index
-                                        '''
-                                        new_hanging_vertex_element[start_node.value.vertex].remove(iel_base+iel_grp)
-                                        for k in first_element_index:
-                                            if k == 0:
-                                                new_hanging_vertex_element[start_node.value.vertex].append(iel_base+iel_grp)
-                                            else:
-                                                new_hanging_vertex_element[start_node.value.vertex].append(iel_base+nelements_in_grp+k - 1)
-                                start_node = start_node.next
-                            # hop along ray from midpoint node to end node
-                            while start_node != end_node:
-                                #replace old (big) element with new
-                                # (refined element
-                                print "BEFOREEEalsdfjlkjasdfl:", mn_idx, mx_idx, start_node.value
-                                node_els = start_node.value.elements
-                                #iold_el = node_els.index(iel_base+iel_grp)
-                                #node_els[iold_el] = iel_base+nelements_in_grp+second_element_index
-                                node_els.remove(iel_base+iel_grp)
-                                for k in second_element_index:
-                                    if k == 0:
-                                        node_els.append(iel_base+iel_grp)
-                                    else:
-                                        node_els.append(iel_base+nelements_in_grp+k-1)
-                                print "AFTERLKjlajsdfljkadf:", mn_idx, mx_idx, start_node.value
-                                if new_hanging_vertex_element[start_node.value.vertex] and \
-                                    new_hanging_vertex_element[start_node.value.vertex].count(
-                                        iel_base+iel_grp):
-                                        '''
-                                        to_replace_index = new_hanging_vertex_element[start_node.value.vertex].\
-                                                index(iel_base+iel_grp)
-                                        new_hanging_vertex_element[start_node.value.vertex][to_replace_index] =\
-                                                iel_base+nelements_in_grp+second_element_index
-                                        '''
-                                        new_hanging_vertex_element[start_node.value.vertex].remove(iel_base+iel_grp)
-                                        for k in second_element_index:
-                                            if k == 0:
-                                                new_hanging_vertex_element[start_node.value.vertex].append(iel_base+iel_grp)
-                                            else:
-                                                new_hanging_vertex_element[start_node.value.vertex].append(iel_base+nelements_in_grp+k-1)
-                                start_node = start_node.next
-
-                        unique_vertex_pairs = [
-                                (i,j) for i in range(len(verts)) for j in range(i+1, len(verts))]
-                        midpoint_ind = 0
-                        for i, j in unique_vertex_pairs:
-                            mn_idx = min(verts[i], verts[j]) 
-                            mx_idx = max(verts[i], verts[j])
-                            vertex_pair = (mn_idx, mx_idx)
-                            print "ASFLKJALFKJASFLKJAFVERTEXPAIR:", vertex_pair
-                            if not vertex_pair in self.pair_map:
-                                continue
-                            print "DOING STUASASLKFJLASKFJ"
-                            element_indices = []
-                            for k_ind, k in enumerate(self.tri_result):
-                                ituple_ind = self.tri_node_tuples.index(self.index_to_midpoint_tuple[i])
-                                jtuple_ind = self.tri_node_tuples.index(self.index_to_midpoint_tuple[j])
-                                if ituple_ind in k and\
-                                    jtuple_ind in k:
-                                        element_indices.append(k_ind)
-                            #print "ELEMENTIDX1: ", element_indices_1
-                            #print "ELEMENTIDX2: ", element_indices_2
-                            cur_pmap = self.pair_map[vertex_pair]
-                            start_node =\
-                            self.vertex_to_ray[mn_idx][cur_pmap.ray]
-                            end_node = self.vertex_to_ray[mx_idx][cur_pmap.ray]
-                            while start_node != end_node:
-                                # replace old (big) element with new
-                                # (refined) element
-                                node_els = start_node.value.elements
-                                #print "OLD NODE ELS: ", node_els
-                                #print node_els
-                                node_els.remove(iel_base+iel_grp)
-                                for k in element_indices:
-                                    if k == 0:
-                                        node_els.append(iel_base+iel_grp)
-                                    else:
-                                        node_els.append(iel_base+nelements_in_grp+k - 1)
-                                #print "NEW_NODE_ELS: ", node_els
-                                #node_els[iold_el] = iel_base+nelements_in_grp+first_element_index
-                                #print "HANGING: ", new_hanging_vertex_element[start_node.value.vertex]
-                                if new_hanging_vertex_element[start_node.value.vertex] and \
-                                    new_hanging_vertex_element[start_node.value.vertex].count(
-                                        iel_base+iel_grp):
-                                        '''
-                                        to_replace_index = new_hanging_vertex_element[start_node.value.vertex].\
-                                                index(iel_base+iel_grp)
-                                        new_hanging_vertex_element[start_node.value.vertex][to_replace_index] =\
-                                                iel_base+nelements_in_grp+first_element_index
-                                        '''
-                                        new_hanging_vertex_element[start_node.value.vertex].remove(iel_base+iel_grp)
-                                        for k in element_indices:
-                                            if k == 0:
-                                                new_hanging_vertex_element[start_node.value.vertex].append(iel_base+iel_grp)
-                                            else:
-                                                new_hanging_vertex_element[start_node.value.vertex].append(iel_base+nelements_in_grp+k - 1)
-                                start_node = start_node.next
-                        # }}}
-                        #TODO: Update existing hanging nodes and elements for rays that may have already been generated by different element
-                        #generate new rays
-                        from llist import dllist, dllistnode
-                        ind = 0
-                        for i in six.moves.range(0, len(verts)):
-                            for j in six.moves.range(i+1, len(verts)):
-                                mn_vert = min(verts[i], verts[j])
-                                mx_vert = max(verts[i], verts[j])
-                                vertex_pair = (mn_vert, mx_vert)
-                                if vertex_pair in self.pair_map:
-                                    continue
-                                els = []
-                                #common_elements = list(set(verts_elements[i]).intersection(
-                                #    verts_elements[j]))
-                                common_elements = list(set(new_hanging_vertex_element[mn_vert]).
-                                        intersection(new_hanging_vertex_element[mx_vert]))
-                                for cel in common_elements:
-                                    els.append(cel)
-                                vert1ind = self.tri_node_tuples.index(self.index_to_midpoint_tuple[i])
-                                vert2ind = self.tri_node_tuples.index(self.index_to_midpoint_tuple[j])
-                                for kind, k in enumerate(self.tri_result):
-                                    if vert1ind in k and vert2ind in k:
-                                        if kind == 0:
-                                            els.append(iel_base+iel_grp)
-                                        else:
-                                            els.append(iel_base + nelements_in_grp + kind - 1)
-                                #print "ELS: ", els
-                                #els.append(iel_base+iel_grp)
-                                #els.append(iel_base+nelements_in_grp+ind)
-                                
-                                fr = Adj(mn_vert, els)
-
-                                # We're creating a new ray, and this is the end node
-                                # of it.
-                                to = Adj(mx_vert, [])
-
-                                self.rays.append(dllist([fr, to]))
-                                self.pair_map[vertex_pair] = PairMap(len(self.rays) - 1)
-                                self.vertex_to_ray[mn_vert][len(self.rays)-1]\
-                                    = self.rays[len(self.rays)-1].nodeat(0)
-                                self.vertex_to_ray[mx_vert][len(self.rays)-1]\
-                                    = self.rays[len(self.rays)-1].nodeat(1)
-                                ind += 1
-                        ind = 0
-                        #map vertex indices to coordinates
-                        node_tuple_to_coord = {}
-                        for node_ind, node_tuple in enumerate(self.index_to_node_tuple):
-                            node_tuple_to_coord[node_tuple] = grp.vertex_indices[iel_grp][node_ind]
-                        for midpoint_ind, midpoint_tuple in enumerate(self.index_to_midpoint_tuple):
-                            node_tuple_to_coord[midpoint_tuple] = verts[midpoint_ind]
-                        o_nelements_in_grp = nelements_in_grp
-                        for i in six.moves.range(0, len(self.tri_result)):
-                            for j in six.moves.range(0, len(self.tri_result[i])):
-                                if i == 0:
-                                    groups[grpn][iel_grp][j] = \
-                                            node_tuple_to_coord[self.tri_node_tuples[self.tri_result[i][j]]]
-                                else:
-                                    #print nelements_in_grp
-                                    groups[grpn][nelements_in_grp][j] = \
-                                        node_tuple_to_coord[self.tri_node_tuples[self.tri_result[i][j]]]
-                            if i != 0:
-                                nelements_in_grp += 1
-
-                        '''
-                        el_ind = 19 
-                        if iel_base + iel_grp == el_ind:
-                            print "LJKASFLKJASFASKL:JFASFLA:SFJ"
-                            for i in six.moves.range(len(groups[grpn][el_ind])):
-                                for j in six.moves.range(i+1, len(groups[grpn][el_ind])):
-                                    mn = min(groups[grpn][el_ind][i], groups[grpn][el_ind][j])
-                                    mx = max(groups[grpn][el_ind][j], groups[grpn][el_ind][i])
-                                    d = self.pair_map[(mn, mx)].d
-                                    ray = self.pair_map[(mn, mx)].ray
-                                    if d:
-                                        print self.vertex_to_ray[mn][ray]
-                                    else:
-                                        print self.vertex_to_ray[mx][ray]
-                        '''
-                        '''
-                        print "REPAIRING:", iel_base+iel_grp
-                        for i in six.moves.range(len(groups[grpn][iel_base+iel_grp])):
-                            for j in six.moves.range(i+1, len(groups[grpn][iel_base+iel_grp])):
-                                mn = min(groups[grpn][iel_base+iel_grp][i], groups[grpn][iel_base+iel_grp][j])
-                                mx = max(groups[grpn][iel_base + iel_grp][i], groups[grpn][iel_base+iel_grp][j])
-                                d = self.pair_map[(mn, mx)].d
-                                ray = self.pair_map[(mn, mx)].ray
-                                if d:
-                                    if iel_base+iel_grp not in self.vertex_to_ray[mn][ray].value.elements:
-                                        self.vertex_to_ray[mn][ray].value.elements.append(iel_base+iel_grp)
-                                else:
-                                    if iel_base + iel_grp not in self.vertex_to_ray[mx][ray].value.elements:
-                                        self.vertex_to_ray[mx][ray].value.elements.append(iel_base+iel_grp)
-                        for elem in six.moves.range(nelements_in_grp):
-                            #print "REPAIRING:", elem
-                            for i in six.moves.range(len(groups[grpn][elem])):
-                                for j in six.moves.range(i+1, len(groups[grpn][elem])):
-                                    mn = min(groups[grpn][elem][i], groups[grpn][elem][j])
-                                    mx = max(groups[grpn][elem][i], groups[grpn][elem][j])
-                                    d = self.pair_map[(mn, mx)].d
-                                    ray = self.pair_map[(mn, mx)].ray
-                                    if d:
-                                        if elem not in self.vertex_to_ray[mn][ray].value.elements:
-                                            print "FAILING!!!!!"
-                                            self.vertex_to_ray[mn][ray].value.elements.append(elem)
-                                    else:
-                                        if elem not in self.vertex_to_ray[mx][ray].value.elements:
-                                            print "FAILING!!!!!"
-                                            self.vertex_to_ray[mx][ray].value.elements.append(elem)
-                        '''
-
-                        #print nelements_in_grp
-                        #print self.tri_node_tuples
-                        #print self.tri_result
-                        '''
-                        #map vertex indices to coordinates
-                        node_tuple_to_coord = {}
-                        node_tuple_to_coord[(0, 0)] = grp.vertex_indices[iel_grp][0]
-                        node_tuple_to_coord[(2, 0)] = grp.vertex_indices[iel_grp][1]
-                        node_tuple_to_coord[(0, 2)] = grp.vertex_indices[iel_grp][2]
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]), \
-                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]))
-                        node_tuple_to_coord[(1, 0)] = self.pair_map[vertex_pair].midpoint 
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]), \
-                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]))
-                        node_tuple_to_coord[(0, 1)] = self.pair_map[vertex_pair].midpoint
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]), \
-                        max(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]))
-                        node_tuple_to_coord[(1, 1)] = self.pair_map[vertex_pair].midpoint
-                        #generate actual elements
-                        #middle element
-                        for i in six.moves.range(0, len(self.tri_result[1])):
-                            groups[grpn][iel_grp][i] = \
-                            node_tuple_to_coord[self.tri_node_tuples[self.tri_result[1][i]]]
-                        for i in six.moves.range(0, 4):
-                            if i == 1:
-                                continue
-                            for j in six.moves.range(0, len(self.tri_result[i])):
-                                groups[grpn][nelements_in_grp][j] = \
-                                        node_tuple_to_coord[self.tri_node_tuples[self.tri_result[i][j]]]
-                            nelements_in_grp += 1
-                        '''
-                        '''
-                        #middle element
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]), \
-                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]))
-                        groups[grpn][iel_grp][0] = self.pair_map[vertex_pair].midpoint
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]), \
-                        max(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]))
-                        groups[grpn][iel_grp][1] = self.pair_map[vertex_pair].midpoint
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]), \
-                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]))
-                        groups[grpn][iel_grp][2] = self.pair_map[vertex_pair].midpoint
-                        #element 0
-                        groups[grpn][nelements_in_grp][0] = grp.vertex_indices[iel_grp][0]
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]), \
-                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][1]))
-                        groups[grpn][nelements_in_grp][1] = self.pair_map[vertex_pair].midpoint
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]), \
-                        max(grp.vertex_indices[iel_grp][0], grp.vertex_indices[iel_grp][2]))
-                        groups[grpn][nelements_in_grp][2] = self.pair_map[vertex_pair].midpoint
-                        nelements_in_grp += 1
-                        #element 1
-                        groups[grpn][nelements_in_grp][0] = grp.vertex_indices[iel_grp][1]
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][0]), \
-                        max(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][0]))
-                        groups[grpn][nelements_in_grp][1] = self.pair_map[vertex_pair].midpoint
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]), \
-                        max(grp.vertex_indices[iel_grp][1], grp.vertex_indices[iel_grp][2]))
-                        groups[grpn][nelements_in_grp][2] = self.pair_map[vertex_pair].midpoint
-                        nelements_in_grp += 1
-                        #element 2
-                        groups[grpn][nelements_in_grp][0] = grp.vertex_indices[iel_grp][2]
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][0]), \
-                        max(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][0]))
-                        groups[grpn][nelements_in_grp][1] = self.pair_map[vertex_pair].midpoint
-                        vertex_pair = (min(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][1]), \
-                        max(grp.vertex_indices[iel_grp][2], grp.vertex_indices[iel_grp][1]))
-                        groups[grpn][nelements_in_grp][2] = self.pair_map[vertex_pair].midpoint
-                        nelements_in_grp += 1
-                        '''
-                        # }}}
-
-                grpn += 1
-
-        self.hanging_vertex_element = new_hanging_vertex_element
-        #print vertices
-        #print vertex_indices
-        from meshmode.mesh.generation import make_group_from_vertices
-        #grp = make_group_from_vertices(vertices, vertex_indices, 4)
-        grp = []
-        grpn = 0
-        for grpn in range(0, len(groups)):
-            grp.append(make_group_from_vertices(vertices, groups[grpn], 4))
-
-        from meshmode.mesh import Mesh
-        #return Mesh(vertices, [grp], element_connectivity=self.generate_connectivity(len(self.last_mesh.groups[group].vertex_indices) \
-        #            + count*3))
-        
-        self.last_mesh = Mesh(vertices, grp, element_connectivity=self.generate_connectivity(totalnelements, nvertices, groups))
-        return self.last_mesh
-        split_faces = {}
-
-        ibase = self.get_refine_base_index()
-        affected_group_indices = set()
-
+        nvertices = self.last_mesh.nvertices
         for grp in self.last_mesh.groups:
-            iel_base
+            iel_base = grp.element_nr_base
+            for iel_grp in six.moves.range(grp.nelements):
+                if refine_flags[iel_base+iel_grp]:
+                    vertex_indices = grp.vertex_indices[iel_grp]
+                    for i in six.moves.range(len(vertex_indices)):
+                        for j in six.moves.range(i+1, len(vertex_indices)):
+                            min_index = min(vertex_indices[i], vertex_indices[j])
+                            max_index = max(vertex_indices[i], vertex_indices[j])
+                            cur_node = self.pair_map[(min_index, max_index)]
+                            if cur_node.midpoint is None:
+                                cur_node.midpoint = nvertices
+                                cur_node.left = TreeRayNode(cur_node.direction, cur_node.adjacent_elements)
+                                cur_node.right = TreeRayNode(not cur_node.direction, cur_node.adjacent_elements)
+                                left_index = cur_node.left.adjacent_elements.index(iel_base+iel_grp)
+                                #right_index = cur_node
+                                #cur_node.left.adjacent_elements[left_index] = 
+                                nvertices += 1
+
 
     def print_rays(self, ind):
         import six