diff --git a/meshmode/mesh/refinement.py b/meshmode/mesh/refinement.py
index 9758f7d2645d46b6bb7aed39de69fd2873f05a69..117c186a4b6251b4ede4c9504e792e5c6a59632a 100644
--- a/meshmode/mesh/refinement.py
+++ b/meshmode/mesh/refinement.py
@@ -250,13 +250,15 @@ class Refiner(object):
         new_hanging_vertex_element = [
                 [] for i in six.moves.range(nvertices)]
 
-        def update_connectivity(element_rays, to_replace, replace_to):
+        def remove_element_from_connectivity(element_rays, to_remove, seen):
             import six
             for node in element_rays:
-                if to_replace in node.adjacent_elements:
-                    node.adjacent_elements.remove(to_replace)
-                if replace_to not in node.adjacent_elements:
-                    node.adjacent_elements.append(replace_to)
+                leaves = self.get_leaves(node)
+                for leaf in leaves:
+                    if (leaf.left_vertex, leaf.right_vertex) not in seen:
+                        print leaf.left_vertex, leaf.right_vertex, to_remove
+                        leaf.adjacent_elements.remove(to_remove)
+                        seen.append((leaf.left_vertex, leaf.right_vertex))
 
             next_element_rays = []
             for i in six.moves.range(len(element_rays)):
@@ -267,11 +269,39 @@ class Refiner(object):
                         vertex_pair = (min_midpoint, max_midpoint)
                         if vertex_pair in self.pair_map:
                             next_element_rays.append(self.pair_map[vertex_pair])
+                            cur_next_rays = [element_rays[i], element_rays[j], self.pair_map[vertex_pair]]
+                            remove_element_from_connectivity(cur_next_rays, to_remove, seen)
                         else:
                             return
                     else:
                         return
-            update_connectivity(next_element_rays, to_replace, replace_to)
+            remove_element_from_connectivity(next_element_rays, to_remove, seen)
+
+        def add_element_to_connectivity(element_rays, to_add, seen):
+            import six
+            for node in element_rays:
+                leaves = self.get_leaves(node)
+                for leaf in leaves:
+                    if (leaf.left_vertex, leaf.right_vertex) not in seen:
+                        leaf.adjacent_elements.append(to_add)
+                        seen.append((leaf.left_vertex, leaf.right_vertex))
+
+            next_element_rays = []
+            for i in six.moves.range(len(element_rays)):
+                for j in six.moves.range(i+1, len(element_rays)):
+                    if element_rays[i].midpoint is not None and element_rays[j].midpoint is not None:
+                        min_midpoint = min(element_rays[i].midpoint, element_rays[j].midpoint)
+                        max_midpoint = max(element_rays[i].midpoint, element_rays[j].midpoint)
+                        vertex_pair = (min_midpoint, max_midpoint)
+                        if vertex_pair in self.pair_map:
+                            next_element_rays.append(self.pair_map[vertex_pair])
+                            cur_next_rays = [element_rays[i], element_rays[j], self.pair_map[vertex_pair]]
+                            add_element_to_connectivity(cur_next_rays, to_add, seen)
+                        else:
+                            return
+                    else:
+                        return
+            add_element_to_connectivity(next_element_rays, to_add, seen)
 
         def add_hanging_vertex_el(v_index, el):
             assert not (v_index == 37 and el == 48)
@@ -293,9 +323,6 @@ class Refiner(object):
                             cur_node = self.pair_map[(min_index, max_index)]
                             assert ((iel_base+iel_grp) in cur_node.adjacent_elements)
 
-                            
-
-
         for i in six.moves.range(len(self.last_mesh.vertices)):
             for j in six.moves.range(len(self.last_mesh.vertices[i])):
                 vertices[i,j] = self.last_mesh.vertices[i,j]
@@ -500,22 +527,6 @@ class Refiner(object):
                             else:
                                 queue.append(vertex.left)
                                 queue.append(vertex.right)
-                    if (7, 37) in self.pair_map and (7, 38) in self.pair_map and (37, 38) in self.pair_map:
-                        print iel_base+iel_grp
-                        #self.print_rays(9)
-                        leaves = self.get_leaves(self.pair_map[(7, 37)])
-                        for i in leaves:
-                            print (7, 37), i.adjacent_elements
-                        leaves = self.get_leaves(self.pair_map[(7, 38)])
-                        for i in leaves:
-                            print (7, 38), i.adjacent_elements
-                        leaves = self.get_leaves(self.pair_map[(37, 38)])
-                        for i in leaves:
-                            print (37, 38), i.adjacent_elements
-
-                        print self.pair_map[(7, 37)].adjacent_elements, self.pair_map[(7, 37)].left, self.pair_map[(7, 37)].right
-                        print self.pair_map[(7, 38)].adjacent_elements, self.pair_map[(7, 38)].left, self.pair_map[(7, 38)].right
-                        print self.pair_map[(37, 38)].adjacent_elements, self.pair_map[(37, 38)].left, self.pair_map[(37, 38)].right
                     #update connectivity of edges in the center
                     unique_vertex_pairs = [
                         (i,j) for i in range(len(midpoint_vertices)) for j in range(i+1,
@@ -543,6 +554,7 @@ class Refiner(object):
                             #if leaf node
                             if vertex.left is None and vertex.right is None:
                                 node_elements = vertex.adjacent_elements
+                                print iel_base+iel_grp
                                 node_elements.remove(iel_base+iel_grp)
                                 for k in element_indices:
                                     if k == 0:
@@ -618,9 +630,31 @@ class Refiner(object):
                                         node_tuple_to_coord[self.simplex_node_tuples[self.simplex_result[i][j]]]
                     
                     #update tet connectivity
+
+                    #remove from connectivity
                     if len(grp.vertex_indices[0]) == 4:
-                        z_element_rays = []
+                        seen_rays = []
+                        for tup_index, tup in enumerate(self.simplex_result):
+                            three_vertex_tuples = [
+                                    (i, j, k) for i in range(len(tup)) for j in range(i+1, len(tup))
+                                    for k in range(j+1, len(tup))]
+                            for i, j, k in three_vertex_tuples:
+                                vertex_i = node_tuple_to_coord[self.simplex_node_tuples[tup[i]]]
+                                vertex_j = node_tuple_to_coord[self.simplex_node_tuples[tup[j]]]
+                                vertex_k = node_tuple_to_coord[self.simplex_node_tuples[tup[k]]]
+                                element_rays = []
+                                element_rays.append(self.pair_map[(
+                                    min(vertex_i, vertex_j), max(vertex_i, vertex_j))])
+                                element_rays.append(self.pair_map[(
+                                    min(vertex_i, vertex_k), max(vertex_i, vertex_k))])
+                                element_rays.append(self.pair_map[(
+                                    min(vertex_j, vertex_k), max(vertex_j, vertex_k))])
+                                remove_element_from_connectivity(element_rays, iel_base+iel_grp,
+                                        seen_rays)
+
+                        #add to connectivity
                         for tup_index, tup in enumerate(self.simplex_result):
+                            seen_rays = []
                             three_vertex_tuples = [
                                     (i, j, k) for i in range(len(tup)) for j in range(i+1, len(tup))
                                     for k in range(j+1, len(tup))]
@@ -636,17 +670,14 @@ class Refiner(object):
                                 element_rays.append(self.pair_map[(
                                     min(vertex_j, vertex_k), max(vertex_j, vertex_k))])
                                 if tup_index != 0:
-                                    update_connectivity(element_rays, iel_base+iel_grp,
-                                            nelements_in_grp+tup_index-1)
+                                    add_element_to_connectivity(element_rays,
+                                            nelements_in_grp+tup_index-1, seen_rays)
                                 else:
-                                    update_connectivity(element_rays, iel_base+iel_grp,
-                                            -1)
-                                    z_element_rays.append(element_rays)
-                        for el in z_element_rays:
-                            update_connectivity(el, -1, iel_base+iel_grp)
+                                    add_element_to_connectivity(element_rays, iel_base+iel_grp,
+                                            update_seen)
                     nelements_in_grp += len(self.simplex_result)-1
                     #assert ray connectivity
-                    check_adjacent_elements(groups, nelements_in_grp)
+                    #check_adjacent_elements(groups, nelements_in_grp)