From 7866d80a9eec955fc233d8948e0c8dc54dca5828 Mon Sep 17 00:00:00 2001
From: Shivam Gupta <sgupta72@illinois.edu>
Date: Mon, 9 May 2016 23:17:19 -0500
Subject: [PATCH] Trying something

---
 meshmode/mesh/refinement.py | 800 +++++++++++++++++++++---------------
 1 file changed, 472 insertions(+), 328 deletions(-)

diff --git a/meshmode/mesh/refinement.py b/meshmode/mesh/refinement.py
index 117c186..1b8d008 100644
--- a/meshmode/mesh/refinement.py
+++ b/meshmode/mesh/refinement.py
@@ -55,8 +55,14 @@ class Refiner(object):
     def __init__(self, mesh):
         #print 'herlkjjlkjasdf'
         from llist import dllist, dllistnode
-        from meshmode.mesh.tesselate  import tesselatetet
-        self.simplex_node_tuples, self.simplex_result = tesselatetet()
+        from meshmode.mesh.tesselate  import tesselatetet, tesselatetri
+        tri_node_tuples, tri_result = tesselatetri()
+        tet_node_tuples, tet_result = tesselatetet()
+        print tri_result, tet_result
+        self.simplex_node_tuples = [None, None, tri_node_tuples, tet_node_tuples]
+        self.simplex_result = [None, None, tri_result, tet_result]
+        #print tri_node_tuples, tri_result
+        #self.simplex_node_tuples, self.simplex_result = tesselatetet()
         self.last_mesh = mesh
         
         # {{{ initialization
@@ -98,26 +104,34 @@ class Refiner(object):
                                 adjacent_elements.append(iel_base+iel_grp))
         # }}}
 
+        print vert_indices
         #generate reference tuples
-        self.index_to_node_tuple = [()] * (len(vert_indices))
-        for i in six.moves.range(0, len(vert_indices)-1):
-            self.index_to_node_tuple[0] = self.index_to_node_tuple[0] + (0,)
-        for i in six.moves.range(1, len(vert_indices)):
-            for j in six.moves.range(1, len(vert_indices)):
-                if i == j:
-                    self.index_to_node_tuple[i] = self.index_to_node_tuple[i] + (2,)
-                else:
-                    self.index_to_node_tuple[i] = self.index_to_node_tuple[i] + (0,)
-        self.index_to_midpoint_tuple = [()] * (int((len(vert_indices) * (len(vert_indices) - 1)) / 2))
-        curind = 0
-        for ind1 in six.moves.range(0, len(self.index_to_node_tuple)):
-            for ind2 in six.moves.range(ind1+1, len(self.index_to_node_tuple)):
-                i = self.index_to_node_tuple[ind1]
-                j = self.index_to_node_tuple[ind2]
-                for k in six.moves.range(0, len(vert_indices)-1):
-                    cur = int((i[k] + j[k]) / 2)
-                    self.index_to_midpoint_tuple[curind] = self.index_to_midpoint_tuple[curind] + (cur,)
-                curind += 1
+        self.index_to_node_tuple = []
+        self.index_to_midpoint_tuple = []
+        for d in six.moves.range(len(vert_indices)):
+            dim = d + 1
+            cur_index_to_node_tuple = [()] * dim
+            for i in six.moves.range(0, dim-1):
+                cur_index_to_node_tuple[0] = cur_index_to_node_tuple[0] + (0,)
+            for i in six.moves.range(1, dim):
+                for j in six.moves.range(1, dim):
+                    if i == j:
+                        cur_index_to_node_tuple[i] = cur_index_to_node_tuple[i] + (2,)
+                    else:
+                        cur_index_to_node_tuple[i] = cur_index_to_node_tuple[i] + (0,)
+            cur_index_to_midpoint_tuple = [()] * (int((dim * (dim - 1)) / 2))
+            curind = 0
+            for ind1 in six.moves.range(0, len(cur_index_to_node_tuple)):
+                for ind2 in six.moves.range(ind1+1, len(cur_index_to_node_tuple)):
+                    i = cur_index_to_node_tuple[ind1]
+                    j = cur_index_to_node_tuple[ind2]
+                    print i, j
+                    for k in six.moves.range(0, dim-1):
+                        cur = int((i[k] + j[k]) / 2)
+                        cur_index_to_midpoint_tuple[curind] = cur_index_to_midpoint_tuple[curind] + (cur,)
+                    curind += 1
+            self.index_to_node_tuple.append(cur_index_to_node_tuple)
+            self.index_to_midpoint_tuple.append(cur_index_to_midpoint_tuple)
         '''
         self.ray_vertices = np.empty([len(mesh.groups[0].vertex_indices) * 
             len(mesh.groups[0].vertex_indices[0]) * (len(mesh.groups[0].vertex_indices[0]) - 1) / 2, 2], 
@@ -209,6 +223,37 @@ class Refiner(object):
                 queue.append(vertex.right)
         return res
 
+    def get_subtree(self, cur_node):
+        queue = [cur_node]
+        res = []
+        while queue:
+            vertex = queue.pop(0)
+            res.append(vertex)
+            if not (vertex.left is None and vertex.right is None):
+                queue.append(vertex.left)
+                queue.append(vertex.right)
+        return res
+    
+    def remove_from_subtree(self, cur_node, new_hanging_vertex_elements, to_remove):
+        subtree = self.get_subtree(cur_node)
+        for node in subtree:
+            if to_remove in node.adjacent_elements:
+                node.adjacent_elements.remove(to_remove)
+            if to_remove in new_hanging_vertex_elements[node.left_vertex]:
+                new_hanging_vertex_elements[node.left_vertex].remove(to_remove)
+            if to_remove in new_hanging_vertex_elements[node.right_vertex]:
+                new_hanging_vertex_elements[node.right_vertex].remove(to_remove)
+
+    def add_to_subtree(self, cur_node, new_hanging_vertex_elements, to_add):
+        subtree = self.get_subtree(cur_node)
+        for node in subtree:
+            if to_add not in node.adjacent_elements:
+                node.adjacent_elements.append(to_add)
+            if to_add not in new_hanging_vertex_elements[node.left_vertex]:
+                new_hanging_vertex_elements[node.left_vertex].append(to_add)
+            if to_add not in new_hanging_vertex_elements[node.right_vertex]:
+                new_hanging_vertex_elements[node.right_vertex].append(to_add)
+
     #refine_flag tells you which elements to split as a numpy array of bools
     def refine(self, refine_flags):
         import six
@@ -231,7 +276,8 @@ class Refiner(object):
                 nelements += 1
                 vertex_indices = grp.vertex_indices[iel_grp]
                 if refine_flags[iel_base+iel_grp]:
-                    nelements += len(self.simplex_result) - 1
+                    cur_dim = len(grp.vertex_indices[iel_grp])-1
+                    nelements += len(self.simplex_result[cur_dim]) - 1
                     for i in six.moves.range(len(vertex_indices)):
                         for j in six.moves.range(i+1, len(vertex_indices)):
                             i_index = vertex_indices[i]
@@ -250,58 +296,143 @@ class Refiner(object):
         new_hanging_vertex_element = [
                 [] for i in six.moves.range(nvertices)]
 
-        def remove_element_from_connectivity(element_rays, to_remove, seen):
+        def remove_element_from_connectivity(vertices, new_hanging_vertex_elements, to_remove):
+            #print vertices
             import six
-            for node in element_rays:
-                leaves = self.get_leaves(node)
-                for leaf in leaves:
-                    if (leaf.left_vertex, leaf.right_vertex) not in seen:
-                        print leaf.left_vertex, leaf.right_vertex, to_remove
-                        leaf.adjacent_elements.remove(to_remove)
-                        seen.append((leaf.left_vertex, leaf.right_vertex))
-
-            next_element_rays = []
-            for i in six.moves.range(len(element_rays)):
-                for j in six.moves.range(i+1, len(element_rays)):
-                    if element_rays[i].midpoint is not None and element_rays[j].midpoint is not None:
-                        min_midpoint = min(element_rays[i].midpoint, element_rays[j].midpoint)
-                        max_midpoint = max(element_rays[i].midpoint, element_rays[j].midpoint)
-                        vertex_pair = (min_midpoint, max_midpoint)
-                        if vertex_pair in self.pair_map:
-                            next_element_rays.append(self.pair_map[vertex_pair])
-                            cur_next_rays = [element_rays[i], element_rays[j], self.pair_map[vertex_pair]]
-                            remove_element_from_connectivity(cur_next_rays, to_remove, seen)
-                        else:
-                            return
+            import itertools
+            if len(vertices) == 2:
+                min_vertex = min(vertices[0], vertices[1])
+                max_vertex = max(vertices[0], vertices[1])
+                ray = self.pair_map[(min_vertex, max_vertex)]
+                self.remove_from_subtree(ray, new_hanging_vertex_elements, to_remove)
+                return
+
+            cur_dim = len(vertices)-1
+            element_rays = []
+            midpoints = []
+            split_possible = True
+            for i in six.moves.range(len(vertices)):
+                for j in six.moves.range(i+1, len(vertices)):
+                    min_vertex = min(vertices[i], vertices[j])
+                    max_vertex = max(vertices[i], vertices[j])
+                    element_rays.append(self.pair_map[(min_vertex, max_vertex)])
+                    if element_rays[len(element_rays)-1].midpoint is not None:
+                        midpoints.append(element_rays[len(element_rays)-1].midpoint)
                     else:
-                        return
-            remove_element_from_connectivity(next_element_rays, to_remove, seen)
+                        split_possible = False
 
-        def add_element_to_connectivity(element_rays, to_add, seen):
-            import six
             for node in element_rays:
-                leaves = self.get_leaves(node)
-                for leaf in leaves:
-                    if (leaf.left_vertex, leaf.right_vertex) not in seen:
-                        leaf.adjacent_elements.append(to_add)
-                        seen.append((leaf.left_vertex, leaf.right_vertex))
-
-            next_element_rays = []
-            for i in six.moves.range(len(element_rays)):
-                for j in six.moves.range(i+1, len(element_rays)):
-                    if element_rays[i].midpoint is not None and element_rays[j].midpoint is not None:
-                        min_midpoint = min(element_rays[i].midpoint, element_rays[j].midpoint)
-                        max_midpoint = max(element_rays[i].midpoint, element_rays[j].midpoint)
-                        vertex_pair = (min_midpoint, max_midpoint)
-                        if vertex_pair in self.pair_map:
-                            next_element_rays.append(self.pair_map[vertex_pair])
-                            cur_next_rays = [element_rays[i], element_rays[j], self.pair_map[vertex_pair]]
-                            add_element_to_connectivity(cur_next_rays, to_add, seen)
-                        else:
-                            return
+                self.remove_from_subtree(node, new_hanging_vertex_elements, to_remove)
+            if split_possible:
+                next_element_rays = []
+                node_tuple_to_coord = {}
+                for node_index, node_tuple in enumerate(self.index_to_node_tuple[cur_dim]):
+                    node_tuple_to_coord[node_tuple] = vertices[node_index]
+                for midpoint_index, midpoint_tuple in enumerate(self.index_to_midpoint_tuple[cur_dim]):
+                    node_tuple_to_coord[midpoint_tuple] = midpoints[midpoint_index]
+                for i in six.moves.range(len(self.simplex_result[cur_dim])):
+                    next_vertices = []
+                    for j in six.moves.range(len(self.simplex_result[cur_dim][i])):
+                        next_vertices.append(node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]])
+                    all_rays_present = True
+                    for v1 in six.moves.range(len(next_vertices)):
+                        for v2 in six.moves.range(v1+1, len(next_vertices)):
+                            if (next_vertices[v1], next_vertices[v2]) not in self.pair_map:
+                                all_rays_present = False
+                    if all_rays_present:
+                        remove_element_from_connectivity(next_vertices, new_hanging_vertex_elements, to_remove)
+            else:
+                next_vertices_list = list(itertools.combinations(vertices, len(vertices)-1)) 
+                for next_vertices in next_vertices_list:
+                    remove_element_from_connectivity(next_vertices, new_hanging_vertex_elements, to_remove)
+
+        def add_element_to_connectivity(vertices, new_hanging_vertex_elements, to_add):
+            print vertices
+            import six
+            import itertools
+            if len(vertices) == 2:
+                min_vertex = min(vertices[0], vertices[1])
+                max_vertex = max(vertices[0], vertices[1])
+                ray = self.pair_map[(min_vertex, max_vertex)]
+                self.add_to_subtree(ray, new_hanging_vertex_elements, to_add)
+                return
+
+            cur_dim = len(vertices)-1
+            element_rays = []
+            midpoints = []
+            split_possible = True
+            for i in six.moves.range(len(vertices)):
+                for j in six.moves.range(i+1, len(vertices)):
+                    min_vertex = min(vertices[i], vertices[j])
+                    max_vertex = max(vertices[i], vertices[j])
+                    element_rays.append(self.pair_map[(min_vertex, max_vertex)])
+                    if element_rays[len(element_rays)-1].midpoint is not None:
+                        midpoints.append(element_rays[len(element_rays)-1].midpoint)
                     else:
-                        return
-            add_element_to_connectivity(next_element_rays, to_add, seen)
+                        split_possible = False
+            print midpoints
+            for node in element_rays:
+                self.add_to_subtree(node, new_hanging_vertex_elements, to_add)
+            if split_possible:
+                next_element_rays = []
+                node_tuple_to_coord = {}
+                for node_index, node_tuple in enumerate(self.index_to_node_tuple[cur_dim]):
+                    node_tuple_to_coord[node_tuple] = vertices[node_index]
+                for midpoint_index, midpoint_tuple in enumerate(self.index_to_midpoint_tuple[cur_dim]):
+                    node_tuple_to_coord[midpoint_tuple] = midpoints[midpoint_index]
+                for i in six.moves.range(len(self.simplex_result[cur_dim])):
+                    next_vertices = []
+                    for j in six.moves.range(len(self.simplex_result[cur_dim][i])):
+                        next_vertices.append(node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]])
+                    all_rays_present = True
+                    for v1 in six.moves.range(len(next_vertices)):
+                        for v2 in six.moves.range(v1+1, len(next_vertices)):
+                            if (next_vertices[v1], next_vertices[v2]) not in self.pair_map:
+                                all_rays_present = False
+                    if all_rays_present:
+                        add_element_to_connectivity(next_vertices, new_hanging_vertex_elements, to_add)
+            else:
+                next_vertices_list = list(itertools.combinations(vertices, len(vertices)-1)) 
+                for next_vertices in next_vertices_list:
+                    add_element_to_connectivity(next_vertices, new_hanging_vertex_elements, to_add)
+#            import six
+#            for node in element_rays:
+#                self.add_element_to_connectivity(node, new_hanging_vertex_elements, to_add)
+ #               leaves = self.get_subtree(node)
+ #               for leaf in leaves:
+ #                   if to_add not in leaf.adjacent_elements:
+ #                       leaf.adjacent_elements.append(to_add)
+ #                   if to_add not in new_hanging_vertex_elements[leaf.left_vertex]:
+ #                       new_hanging_vertex_elements[leaf.left_vertex].append(to_add)
+ #                   if to_add not in new_hanging_vertex_elements[leaf.right_vertex]:
+ #                       new_hanging_vertex_elements[leaf.right_vertex].append(to_add)
+
+#            next_element_rays = []
+#            for i in six.moves.range(len(element_rays)):
+#                for j in six.moves.range(i+1, len(element_rays)):
+#                    if element_rays[i].midpoint is not None and element_rays[j].midpoint is not None:
+#                        min_midpoint = min(element_rays[i].midpoint, element_rays[j].midpoint)
+#                        max_midpoint = max(element_rays[i].midpoint, element_rays[j].midpoint)
+#                        vertex_pair = (min_midpoint, max_midpoint)
+#                        if vertex_pair in self.pair_map:
+#                            next_element_rays.append(self.pair_map[vertex_pair])
+#                            cur_next_rays = []
+#                            if element_rays[i].left_vertex == element_rays[j].left_vertex:
+#                                cur_next_rays = [element_rays[i].left, element_rays[j].left, self.pair_map[vertex_pair]]
+#                            if element_rays[i].right_vertex == element_rays[j].right_vertex:
+#                                cur_next_rays = [element_rays[i].right, element_rays[j].right, self.pair_map[vertex_pair]]
+#                            if element_rays[i].left_vertex == element_rays[j].right_vertex:
+#                                cur_next_rays = [element_rays[i].left, element_rays[j].right, self.pair_map[vertex_pair]]
+#                            if element_rays[i].right_vertex == element_rays[j].left_vertex:
+#                                cur_next_rays = [element_rays[i].right, element_rays[j].left, self.pair_map[vertex_pair]]
+#                            assert (cur_next_rays != [])
+#                            #print cur_next_rays
+#                            add_element_to_connectivity(cur_next_rays, new_hanging_vertex_elements, to_add)
+#                        else:
+#                            return
+#                    else:
+#                        return
+#            add_element_to_connectivity(next_element_rays, new_hanging_vertex_elements, to_add)
 
         def add_hanging_vertex_el(v_index, el):
             assert not (v_index == 37 and el == 48)
@@ -321,6 +452,10 @@ class Refiner(object):
                             min_index = min(vertex_indices[i], vertex_indices[j])
                             max_index = max(vertex_indices[i], vertex_indices[j])
                             cur_node = self.pair_map[(min_index, max_index)]
+                            #print iel_base+iel_grp, cur_node.left_vertex, cur_node.right_vertex
+                            if (iel_base + iel_grp) not in cur_node.adjacent_elements:
+                                print min_index, max_index
+                                print iel_base + iel_grp, cur_node.left_vertex, cur_node.right_vertex, cur_node.adjacent_elements
                             assert ((iel_base+iel_grp) in cur_node.adjacent_elements)
 
         for i in six.moves.range(len(self.last_mesh.vertices)):
@@ -372,6 +507,7 @@ class Refiner(object):
                                         add_hanging_vertex_el(vertices_index, el)
                                 #compute midpoint coordinates
                                 for k in six.moves.range(len(self.last_mesh.vertices)):
+                                    print 'STUFF:', k, vertices_index, vertex_indices[i], vertex_indices[j]
                                     vertices[k, vertices_index] = \
                                     (self.last_mesh.vertices[k, vertex_indices[i]] +
                                     self.last_mesh.vertices[k, vertex_indices[j]]) / 2.0
@@ -393,204 +529,193 @@ class Refiner(object):
                                     if el != (iel_base + iel_grp) and el not in (
                                         vertex_elements[len(vertex_elements)-1]):
                                         vertex_elements[len(vertex_elements)-1].append(el)
-                                if (iel_base+iel_grp) in new_hanging_vertex_element[cur_midpoint]:
-                                    new_hanging_vertex_element[cur_midpoint].remove(iel_base+iel_grp)
+#                                if (iel_base+iel_grp) in new_hanging_vertex_element[cur_midpoint]:
+#                                    new_hanging_vertex_element[cur_midpoint].remove(iel_base+iel_grp)
                                 midpoint_vertices.append(cur_midpoint)
 
-                    #fix connectivity for elements
-                    unique_vertex_pairs = [
-                        (i, j) for i in range(len(vertex_indices)) for j in range(
-                            i+1, len(vertex_indices))]
-                    midpoint_index = 0
-                    for i, j in unique_vertex_pairs:
-                        min_index = min(vertex_indices[i], vertex_indices[j])
-                        max_index = max(vertex_indices[i], vertex_indices[j])
-                        element_indices_1 = []
-                        element_indices_2 = []
-                        for k_index, k, in enumerate(self.simplex_result):
-                            ituple_index = self.simplex_node_tuples.index(
-                                self.index_to_node_tuple[i])
-                            jtuple_index = self.simplex_node_tuples.index(
-                                self.index_to_node_tuple[j])
-                            midpoint_tuple_index = self.simplex_node_tuples.index(
-                                self.index_to_midpoint_tuple[midpoint_index])
-                            if ituple_index in k and midpoint_tuple_index in k:
-                                element_indices_1.append(k_index)
-                            if jtuple_index in k and midpoint_tuple_index in k:
-                                element_indices_2.append(k_index)
-                        midpoint_index += 1
-                        if min_index == vertex_indices[i]:
-                            min_element_index = element_indices_1
-                            max_element_index = element_indices_2
-                        else:
-                            min_element_index = element_indices_2
-                            max_element_index = element_indices_1
-                        vertex_pair = (min_index, max_index)
-                        cur_node = self.pair_map[vertex_pair]
-                        '''
-                        if cur_node.direction:
-                            first_element_index = min_element_index
-                            second_element_index = max_element_index
-                        else:
-                            first_element_index = max_element_index
-                            second_element_index = min_element_index
-                        '''
-                        first_element_index = min_element_index
-                        second_element_index = max_element_index
-                        queue = [cur_node.left]
-                        while queue:
-                            vertex = queue.pop(0)
-                            #if leaf node
-                            if vertex.left is None and vertex.right is None:
-                                node_elements = vertex.adjacent_elements
-
-                                remove_ray_el(node_elements, iel_base+iel_grp)
-                                #node_elements.remove(iel_base+iel_grp)
-                                for k in first_element_index:
-                                    if k == 0:
-                                        node_elements.append(iel_base+iel_grp)
-                                    else:
-                                        node_elements.append(iel_base+nelements_in_grp+k-1)
-                                if new_hanging_vertex_element[vertex.left_vertex] and \
-                                        new_hanging_vertex_element[vertex.left_vertex].count(
-                                        iel_base+iel_grp):
-                                    new_hanging_vertex_element[vertex.left_vertex].remove(
-                                        iel_base+iel_grp)
-                                    for k in first_element_index:
-                                        if k == 0:
-                                            el_to_add = iel_base+iel_grp
-                                        else:
-                                            el_to_add = iel_base+nelements_in_grp+k-1
-
-                                        add_hanging_vertex_el(vertex.left_vertex,
-                                                el_to_add)
-                                        del el_to_add
-
-                                if new_hanging_vertex_element[vertex.right_vertex] and \
-                                        new_hanging_vertex_element[vertex.right_vertex].count(
-                                        iel_base+iel_grp):
-                                    new_hanging_vertex_element[vertex.right_vertex].remove(
-                                        iel_base+iel_grp)
-                                    for k in first_element_index:
-                                        if k == 0:
-                                            el_to_add = iel_base+iel_grp
-                                        else:
-                                            el_to_add = iel_base+nelements_in_grp+k-1
-
-                                        add_hanging_vertex_el(vertex.right_vertex, el_to_add)
-                                        del el_to_add
-                            else:
-                                queue.append(vertex.left)
-                                queue.append(vertex.right)
-
-                        queue = [cur_node.right]
-                        while queue:
-                            vertex = queue.pop(0)
-                            #if leaf node
-                            if vertex.left is None and vertex.right is None:
-                                node_elements = vertex.adjacent_elements
-                                #node_elements.remove(iel_base+iel_grp)
-                                remove_ray_el(node_elements, iel_base+iel_grp)
-                                for k in second_element_index:
-                                    if k == 0:
-                                        node_elements.append(iel_base+iel_grp)
-                                    else:
-                                        node_elements.append(iel_base+nelements_in_grp+k-1)
-                                if new_hanging_vertex_element[vertex.left_vertex] and \
-                                    new_hanging_vertex_element[vertex.left_vertex].count(
-                                    iel_base+iel_grp):
-                                    new_hanging_vertex_element[vertex.left_vertex].remove(
-                                        iel_base+iel_grp)
-                                    for k in second_element_index:
-                                        if k == 0:
-                                            el_to_add = iel_base+iel_grp
-                                        else:
-                                            el_to_add = iel_base+nelements_in_grp+k-1
+#                    #fix connectivity for elements
+#                    unique_vertex_pairs = [
+#                        (i, j) for i in range(len(vertex_indices)) for j in range(
+#                            i+1, len(vertex_indices))]
+#                    midpoint_index = 0
+#                    for i, j in unique_vertex_pairs:
+#                        min_index = min(vertex_indices[i], vertex_indices[j])
+#                        max_index = max(vertex_indices[i], vertex_indices[j])
+#                        element_indices_1 = []
+#                        element_indices_2 = []
+#                        for k_index, k, in enumerate(self.simplex_result):
+#                            ituple_index = self.simplex_node_tuples.index(
+#                                self.index_to_node_tuple[i])
+#                            jtuple_index = self.simplex_node_tuples.index(
+#                                self.index_to_node_tuple[j])
+#                            midpoint_tuple_index = self.simplex_node_tuples.index(
+#                                self.index_to_midpoint_tuple[midpoint_index])
+#                            if ituple_index in k and midpoint_tuple_index in k:
+#                                element_indices_1.append(k_index)
+#                            if jtuple_index in k and midpoint_tuple_index in k:
+#                                element_indices_2.append(k_index)
+#                        midpoint_index += 1
+#                        if min_index == vertex_indices[i]:
+#                            min_element_index = element_indices_1
+#                            max_element_index = element_indices_2
+#                        else:
+#                            min_element_index = element_indices_2
+#                            max_element_index = element_indices_1
+#                        vertex_pair = (min_index, max_index)
+#                        cur_node = self.pair_map[vertex_pair]
+#                        '''
+#                        if cur_node.direction:
+#                            first_element_index = min_element_index
+#                            second_element_index = max_element_index
+#                        else:
+#                            first_element_index = max_element_index
+#                            second_element_index = min_element_index
+#                        '''
+#                        first_element_index = min_element_index
+#                        second_element_index = max_element_index
+#                        subtree = self.get_subtree(cur_node.left)
+#                        for vertex in subtree:
+#                            node_elements = vertex.adjacent_elements
+
+#                            '''
+#                            remove_ray_el(node_elements, iel_base+iel_grp)
+#                            #node_elements.remove(iel_base+iel_grp)
+#                            for k in first_element_index:
+#                                if k == 0:
+#                                    node_elements.append(iel_base+iel_grp)
+#                                else:
+#                                    node_elements.append(iel_base+nelements_in_grp+k-1)
+#                            '''
+#                            if new_hanging_vertex_element[vertex.left_vertex] and \
+#                                    new_hanging_vertex_element[vertex.left_vertex].count(
+#                                    iel_base+iel_grp):
+#                                new_hanging_vertex_element[vertex.left_vertex].remove(
+#                                    iel_base+iel_grp)
+#                                for k in first_element_index:
+#                                    if k == 0:
+#                                        el_to_add = iel_base+iel_grp
+#                                    else:
+#                                        el_to_add = iel_base+nelements_in_grp+k-1
+#
+#                                    add_hanging_vertex_el(vertex.left_vertex,
+#                                            el_to_add)
+#                                    del el_to_add
+#
+#                            if new_hanging_vertex_element[vertex.right_vertex] and \
+#                                    new_hanging_vertex_element[vertex.right_vertex].count(
+#                                    iel_base+iel_grp):
+#                                new_hanging_vertex_element[vertex.right_vertex].remove(
+#                                    iel_base+iel_grp)
+#                                for k in first_element_index:
+#                                    if k == 0:
+#                                        el_to_add = iel_base+iel_grp
+#                                    else:
+#                                        el_to_add = iel_base+nelements_in_grp+k-1
+#
+#                                    add_hanging_vertex_el(vertex.right_vertex, el_to_add)
+#                                    del el_to_add
+
+#                        subtree = self.get_subtree(cur_node.right)
+#                        for vertex in subtree:
+#                            node_elements = vertex.adjacent_elements
+#                            #node_elements.remove(iel_base+iel_grp)
+#                            '''
+#                            remove_ray_el(node_elements, iel_base+iel_grp)
+#                            for k in second_element_index:
+#                                if k == 0:
+#                                    node_elements.append(iel_base+iel_grp)
+#                                else:
+#                                    node_elements.append(iel_base+nelements_in_grp+k-1)
+#                            '''
+#                            if new_hanging_vertex_element[vertex.left_vertex] and \
+#                                new_hanging_vertex_element[vertex.left_vertex].count(
+#                                iel_base+iel_grp):
+#                                new_hanging_vertex_element[vertex.left_vertex].remove(
+#                                    iel_base+iel_grp)
+#                                for k in second_element_index:
+#                                    if k == 0:
+#                                        el_to_add = iel_base+iel_grp
+#                                    else:
+#                                        el_to_add = iel_base+nelements_in_grp+k-1
+#
+#                                    add_hanging_vertex_el(vertex.left_vertex, el_to_add)
+#
+#                                    del el_to_add
+#
+#                            if new_hanging_vertex_element[vertex.right_vertex] and \
+#                                new_hanging_vertex_element[vertex.right_vertex].count(
+#                                iel_base+iel_grp):
+#                                new_hanging_vertex_element[vertex.right_vertex].remove(
+#                                    iel_base+iel_grp)
+#                                for k in second_element_index:
+#                                    if k == 0:
+#                                        el_to_add = iel_base+iel_grp
+#                                    else:
+#                                        el_to_add = iel_base+nelements_in_grp+k-1
+#
+#                                    add_hanging_vertex_el(vertex.right_vertex, el_to_add)
+#                                    del el_to_add
 
-                                        add_hanging_vertex_el(vertex.left_vertex, el_to_add)
-
-                                        del el_to_add
-
-                                if new_hanging_vertex_element[vertex.right_vertex] and \
-                                    new_hanging_vertex_element[vertex.right_vertex].count(
-                                    iel_base+iel_grp):
-                                    new_hanging_vertex_element[vertex.right_vertex].remove(
-                                        iel_base+iel_grp)
-                                    for k in second_element_index:
-                                        if k == 0:
-                                            el_to_add = iel_base+iel_grp
-                                        else:
-                                            el_to_add = iel_base+nelements_in_grp+k-1
-
-                                        add_hanging_vertex_el(vertex.right_vertex, el_to_add)
-                                        del el_to_add
-                            else:
-                                queue.append(vertex.left)
-                                queue.append(vertex.right)
                     #update connectivity of edges in the center
-                    unique_vertex_pairs = [
-                        (i,j) for i in range(len(midpoint_vertices)) for j in range(i+1,
-                            len(midpoint_vertices))]
-                    midpoint_index = 0
-                    for i, j in unique_vertex_pairs:
-                        min_index = min(midpoint_vertices[i], midpoint_vertices[j])
-                        max_index = max(midpoint_vertices[i], midpoint_vertices[j])
-                        vertex_pair = (min_index, max_index)
-                        if vertex_pair not in self.pair_map:
-                            continue
-                        element_indices = []
-                        for k_index, k in enumerate(self.simplex_result):
-                            ituple_index = self.simplex_node_tuples.index(
-                                self.index_to_midpoint_tuple[i])
-                            jtuple_index = self.simplex_node_tuples.index(
-                                self.index_to_midpoint_tuple[j])
-                            if ituple_index in k and jtuple_index in k:
-                                element_indices.append(k_index)
-
-                        cur_node = self.pair_map[vertex_pair]
-                        queue = [cur_node]
-                        while queue:
-                            vertex = queue.pop(0)
-                            #if leaf node
-                            if vertex.left is None and vertex.right is None:
-                                node_elements = vertex.adjacent_elements
-                                print iel_base+iel_grp
-                                node_elements.remove(iel_base+iel_grp)
-                                for k in element_indices:
-                                    if k == 0:
-                                        node_elements.append(iel_base+iel_grp)
-                                    else:
-                                        node_elements.append(iel_base+nelements_in_grp+k-1)
-                                if new_hanging_vertex_element[vertex.left_vertex] and \
-                                    new_hanging_vertex_element[vertex.left_vertex].count(
-                                    iel_base+iel_grp):
-                                    new_hanging_vertex_element[vertex.left_vertex].remove(
-                                        iel_base+iel_grp)
-                                    for k in element_indices:
-                                        if k == 0:
-                                            el_to_add = iel_base+iel_grp
-                                        else:
-                                            el_to_add = iel_base+nelements_in_grp+k-1
-                                        add_hanging_vertex_el(vertex.left_vertex, el_to_add)
-                                        del el_to_add
-                                        
-                                if new_hanging_vertex_element[vertex.right_vertex] and \
-                                    new_hanging_vertex_element[vertex.right_vertex].count(
-                                    iel_base+iel_grp):
-                                    new_hanging_vertex_element[vertex.right_vertex].remove(
-                                        iel_base+iel_grp)
-                                    for k in second_element_index:
-                                        if k == 0:
-                                            el_to_add = iel_base+iel_grp
-                                        else:
-                                            el_to_add = iel_base+nelements_in_grp+k-1
-
-                                        add_hanging_vertex_el(vertex.right_vertex, el_to_add)
-                                        del el_to_add
-                            else:
-                                queue.append(vertex.left)
-                                queue.append(vertex.right)
+#                    unique_vertex_pairs = [
+#                        (i,j) for i in range(len(midpoint_vertices)) for j in range(i+1,
+#                            len(midpoint_vertices))]
+#                    midpoint_index = 0
+#                    for i, j in unique_vertex_pairs:
+#                        min_index = min(midpoint_vertices[i], midpoint_vertices[j])
+#                        max_index = max(midpoint_vertices[i], midpoint_vertices[j])
+#                        vertex_pair = (min_index, max_index)
+#                        if vertex_pair not in self.pair_map:
+#                            continue
+#                        element_indices = []
+#                        for k_index, k in enumerate(self.simplex_result):
+#                            ituple_index = self.simplex_node_tuples.index(
+#                                self.index_to_midpoint_tuple[i])
+#                            jtuple_index = self.simplex_node_tuples.index(
+#                                self.index_to_midpoint_tuple[j])
+#                            if ituple_index in k and jtuple_index in k:
+#                                element_indices.append(k_index)
+
+#                        cur_node = self.pair_map[vertex_pair]
+#                        subtree = self.get_subtree(cur_node)
+#                        for vertex in subtree:
+#                            node_elements = vertex.adjacent_elements
+#                            #print iel_base+iel_grp
+#                            node_elements.remove(iel_base+iel_grp)
+#                            for k in element_indices:
+#                                if k == 0:
+#                                    node_elements.append(iel_base+iel_grp)
+#                                else:
+#                                    node_elements.append(iel_base+nelements_in_grp+k-1)
+#                            if new_hanging_vertex_element[vertex.left_vertex] and \
+#                                new_hanging_vertex_element[vertex.left_vertex].count(
+#                                iel_base+iel_grp):
+#                                new_hanging_vertex_element[vertex.left_vertex].remove(
+#                                    iel_base+iel_grp)
+#                                for k in element_indices:
+#                                    if k == 0:
+#                                        el_to_add = iel_base+iel_grp
+#                                    else:
+#                                        el_to_add = iel_base+nelements_in_grp+k-1
+#                                    add_hanging_vertex_el(vertex.left_vertex, el_to_add)
+#                                    del el_to_add
+#                                    
+#                            if new_hanging_vertex_element[vertex.right_vertex] and \
+#                                new_hanging_vertex_element[vertex.right_vertex].count(
+#                                iel_base+iel_grp):
+#                                new_hanging_vertex_element[vertex.right_vertex].remove(
+#                                    iel_base+iel_grp)
+#                                for k in second_element_index:
+#                                    if k == 0:
+#                                        el_to_add = iel_base+iel_grp
+#                                    else:
+#                                        el_to_add = iel_base+nelements_in_grp+k-1
+#
+#                                    add_hanging_vertex_el(vertex.right_vertex, el_to_add)
+#                                    del el_to_add
+
                     #generate new rays
+                    cur_dim = len(grp.vertex_indices[0])-1
                     for i in six.moves.range(len(midpoint_vertices)):
                         for j in six.moves.range(i+1, len(midpoint_vertices)):
                             min_index = min(midpoint_vertices[i], midpoint_vertices[j])
@@ -598,86 +723,105 @@ class Refiner(object):
                             vertex_pair = (min_index, max_index)
                             if vertex_pair in self.pair_map:
                                 continue
-                            elements = []
-                            common_elements = list(set(new_hanging_vertex_element[min_index]).
-                                intersection(new_hanging_vertex_element[max_index]))
-                            for cel in common_elements:
-                                elements.append(cel)
-                            vertex_1_index = self.simplex_node_tuples.index(
-                                self.index_to_midpoint_tuple[i])
-                            vertex_2_index = self.simplex_node_tuples.index(
-                                self.index_to_midpoint_tuple[j])
-                            for kind, k in enumerate(self.simplex_result):
-                                if vertex_1_index in k and vertex_2_index in k:
-                                    if kind == 0:
-                                        elements.append(iel_base+iel_grp)
-                                    else:
-                                        elements.append(iel_base+nelements_in_grp+kind-1)
+#                            elements = []
+#                            common_elements = list(set(new_hanging_vertex_element[min_index]).
+#                                intersection(new_hanging_vertex_element[max_index]))
+#                            for cel in common_elements:
+#                                elements.append(cel)
+#                            vertex_1_index = self.simplex_node_tuples[cur_dim].index(
+#                                self.index_to_midpoint_tuple[i])
+#                            vertex_2_index = self.simplex_node_tuples[cur_dim].index(
+#                                self.index_to_midpoint_tuple[j])
+                            
+#                            for kind, k in enumerate(self.simplex_result):
+#                                if vertex_1_index in k and vertex_2_index in k:
+#                                    if kind == 0:
+#                                        elements.append(iel_base+iel_grp)
+#                                    else:
+#                                        elements.append(iel_base+nelements_in_grp+kind-1)
+#                            #print min_index, max_index, elements
                             self.pair_map[vertex_pair] = TreeRayNode(min_index, max_index,
-                                    True, elements)
+                                    True, [])
                     node_tuple_to_coord = {}
-                    for node_index, node_tuple in enumerate(self.index_to_node_tuple):
+                    for node_index, node_tuple in enumerate(self.index_to_node_tuple[cur_dim]):
                         node_tuple_to_coord[node_tuple] = grp.vertex_indices[iel_grp][node_index]
-                    for midpoint_index, midpoint_tuple in enumerate(self.index_to_midpoint_tuple):
+                    for midpoint_index, midpoint_tuple in enumerate(self.index_to_midpoint_tuple[cur_dim]):
                         node_tuple_to_coord[midpoint_tuple] = midpoint_vertices[midpoint_index]
-                    for i in six.moves.range(len(self.simplex_result)):
-                        for j in six.moves.range(len(self.simplex_result[i])):
+                    for i in six.moves.range(len(self.simplex_result[cur_dim])):
+                        for j in six.moves.range(len(self.simplex_result[cur_dim][i])):
                             if i == 0:
+                                print node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
+                                print 'GRP:', groups[grpn][iel_grp]
                                 groups[grpn][iel_grp][j] = \
-                                        node_tuple_to_coord[self.simplex_node_tuples[self.simplex_result[i][j]]]
+                                        node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
                             else:
+                                print self.simplex_result[cur_dim][i][j]
+                                print node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
+                                print i, j, cur_dim
                                 groups[grpn][nelements_in_grp+i-1][j] = \
-                                        node_tuple_to_coord[self.simplex_node_tuples[self.simplex_result[i][j]]]
-                    
+                                        node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
                     #update tet connectivity
 
                     #remove from connectivity
-                    if len(grp.vertex_indices[0]) == 4:
-                        seen_rays = []
-                        for tup_index, tup in enumerate(self.simplex_result):
-                            three_vertex_tuples = [
-                                    (i, j, k) for i in range(len(tup)) for j in range(i+1, len(tup))
-                                    for k in range(j+1, len(tup))]
-                            for i, j, k in three_vertex_tuples:
-                                vertex_i = node_tuple_to_coord[self.simplex_node_tuples[tup[i]]]
-                                vertex_j = node_tuple_to_coord[self.simplex_node_tuples[tup[j]]]
-                                vertex_k = node_tuple_to_coord[self.simplex_node_tuples[tup[k]]]
-                                element_rays = []
-                                element_rays.append(self.pair_map[(
-                                    min(vertex_i, vertex_j), max(vertex_i, vertex_j))])
-                                element_rays.append(self.pair_map[(
-                                    min(vertex_i, vertex_k), max(vertex_i, vertex_k))])
-                                element_rays.append(self.pair_map[(
-                                    min(vertex_j, vertex_k), max(vertex_j, vertex_k))])
-                                remove_element_from_connectivity(element_rays, iel_base+iel_grp,
-                                        seen_rays)
+                    #if len(grp.vertex_indices[0]) == 4:
+                    for tup_index, tup in enumerate(self.simplex_result[cur_dim]):
+                        rem_vertices = []
+                        for vert in tup:
+                            rem_vertices.append(node_tuple_to_coord[self.simplex_node_tuples[cur_dim][vert]])
+                        remove_element_from_connectivity(rem_vertices, new_hanging_vertex_element, iel_base+iel_grp)
+#                        three_vertex_tuples = [
+#                                (i, j, k) for i in range(len(tup)) for j in range(i+1, len(tup))
+#                                for k in range(j+1, len(tup))]
+#                        for i, j, k in three_vertex_tuples:
+                        
+#                        vertex_i = node_tuple_to_coord[self.simplex_node_tuples[tup[i]]]
+#                        vertex_j = node_tuple_to_coord[self.simplex_node_tuples[tup[j]]]
+#                        vertex_k = node_tuple_to_coord[self.simplex_node_tuples[tup[k]]]
+#                        element_rays.append(self.pair_map[(
+#                            min(vertex_i, vertex_j), max(vertex_i, vertex_j))])
+#                        element_rays.append(self.pair_map[(
+#                            min(vertex_i, vertex_k), max(vertex_i, vertex_k))])
+#                        element_rays.append(self.pair_map[(
+#                            min(vertex_j, vertex_k), max(vertex_j, vertex_k))])
 
                         #add to connectivity
-                        for tup_index, tup in enumerate(self.simplex_result):
-                            seen_rays = []
-                            three_vertex_tuples = [
-                                    (i, j, k) for i in range(len(tup)) for j in range(i+1, len(tup))
-                                    for k in range(j+1, len(tup))]
-                            for i, j, k in three_vertex_tuples:
-                                vertex_i = node_tuple_to_coord[self.simplex_node_tuples[tup[i]]]
-                                vertex_j = node_tuple_to_coord[self.simplex_node_tuples[tup[j]]]
-                                vertex_k = node_tuple_to_coord[self.simplex_node_tuples[tup[k]]]
-                                element_rays = []
-                                element_rays.append(self.pair_map[(
-                                    min(vertex_i, vertex_j), max(vertex_i, vertex_j))])
-                                element_rays.append(self.pair_map[(
-                                    min(vertex_i, vertex_k), max(vertex_i, vertex_k))])
-                                element_rays.append(self.pair_map[(
-                                    min(vertex_j, vertex_k), max(vertex_j, vertex_k))])
-                                if tup_index != 0:
-                                    add_element_to_connectivity(element_rays,
-                                            nelements_in_grp+tup_index-1, seen_rays)
-                                else:
-                                    add_element_to_connectivity(element_rays, iel_base+iel_grp,
-                                            update_seen)
-                    nelements_in_grp += len(self.simplex_result)-1
+                        for tup_index, tup in enumerate(self.simplex_result[cur_dim]):
+                            add_vertices = []
+                            print 'TUP:', tup
+                            for vert in tup:
+                                add_vertices.append(node_tuple_to_coord[self.simplex_node_tuples[cur_dim][vert]])
+                            if tup_index == 0:
+                                add_element_to_connectivity(add_vertices, new_hanging_vertex_element,
+                                        iel_base+iel_grp)
+                            else:
+                                add_element_to_connectivity(add_vertices, new_hanging_vertex_element,
+                                    nelements_in_grp+tup_index-1)
+#                            three_vertex_tuples = [
+#                                    (i, j, k) for i in range(len(tup)) for j in range(i+1, len(tup))
+#                                    for k in range(j+1, len(tup))]
+#                            for i, j, k in three_vertex_tuples:
+#                                vertex_i = node_tuple_to_coord[self.simplex_node_tuples[tup[i]]]
+#                                vertex_j = node_tuple_to_coord[self.simplex_node_tuples[tup[j]]]
+#                                vertex_k = node_tuple_to_coord[self.simplex_node_tuples[tup[k]]]
+#                                element_rays = []
+#                                element_rays.append(self.pair_map[(
+#                                    min(vertex_i, vertex_j), max(vertex_i, vertex_j))])
+#                                element_rays.append(self.pair_map[(
+#                                    min(vertex_i, vertex_k), max(vertex_i, vertex_k))])
+#                                element_rays.append(self.pair_map[(
+#                                    min(vertex_j, vertex_k), max(vertex_j, vertex_k))])
+#                                if tup_index != 0:
+#                                    if (nelements_in_grp+tup_index-1) == 263:
+#                                        print "VS:", vertex_i, vertex_j, vertex_k
+#                                        for ray in element_rays:
+#                                            print ray.left_vertex, ray.right_vertex
+#                                    add_element_to_connectivity(element_rays, new_hanging_vertex_element,
+#                                            nelements_in_grp+tup_index-1)
+#                                else:
+#                                    add_element_to_connectivity(element_rays, new_hanging_vertex_element, iel_base+iel_grp)
+                    nelements_in_grp += len(self.simplex_result[cur_dim])-1
                     #assert ray connectivity
-                    #check_adjacent_elements(groups, nelements_in_grp)
+                    check_adjacent_elements(groups, nelements_in_grp)
 
 
 
-- 
GitLab