diff --git a/meshmode/mesh/refinement.py b/meshmode/mesh/refinement.py index 7ea840c11fa778eb8a84e22511246c622179c583..0ea9694402a7cec04128650234508799bc598000 100644 --- a/meshmode/mesh/refinement.py +++ b/meshmode/mesh/refinement.py @@ -52,7 +52,7 @@ class Refiner(object): def __init__(self, mesh): from llist import dllist, dllistnode from meshmode.mesh.tesselate import tesselatetet - self.tri_node_tuples, self.tri_result = tesselatetet() + self.simplex_node_tuples, self.simplex_result = tesselatetet() self.last_mesh = mesh # {{{ initialization @@ -196,25 +196,135 @@ class Refiner(object): #refine_flag tells you which elements to split as a numpy array of bools def refine(self, refine_flags): import six + import numpy as np + from sets import Set + #vertices and groups for next generation nvertices = self.last_mesh.nvertices + + vertices = np.empty([len(self.last_mesh.vertices), nvertices]) + groups = [] + + midpoint_already = Set([]) + grpn = 0 + totalnelements = 0 + new_hanging_vertex_element = [] + + for i in six.moves.range(nvertices): + new_hanging_vertex_element.append([]) + for grp in self.last_mesh.groups: iel_base = grp.element_nr_base + nelements = 0 for iel_grp in six.moves.range(grp.nelements): + nelements += 1 + vertex_indices = grp.vertex_indices[iel_grp] if refine_flags[iel_base+iel_grp]: + nelements += len(self.simplex_result) - 1 + for i in six.moves.range(len(vertex_indices)): + for j in six.moves.range(i+1, len(vertex_indices)): + i_index = vertex_indices[i] + j_index = vertex_indices[j] + index_tuple = (i_index, j_index) if i_index < j_index else (j_index, i_index) + if index_tuple not in midpoint_already and \ + self.pair_map[vertex_pair].midpoint is None: + nvertices += 1 + midpoint_already.add(index_tuple) + groups.append(np.empty([nelements, len(grp.vertex_indices[0])], dtype=np.int32)) + grpn += 1 + totalnelements += nelements + for i in six.moves.range(len(self.last_mesh.vertices)): + for j in six.moves.range(len(self.last_mesh.vertices[i])): + vertices[i,j] = self.last_mesh.vertices[i,j] + import copy + if i == 0: + new_hanging_vertex_element[j] = copy.deepcopy(self.hanging_vertex_element[j]) + grpn = 0 + for grp in self.last_mesh.groups: + for iel_grp in six.moves.range(grp.nelements): + for i in six.moves.range(len(grp.vertex_indices[iel_grp])): + groups[grpn][iel_grp][i] = grp.vertex_indices[iel_grp][i] + grpn += 1 + + grpn = 0 + vertices_index = self.last_mesh.nvertices + for grp in self.last_mesh.groups: + iel_base = grp.element_nr_base + for iel_grp in six.moves.range(grp.nelements): + if refine_flags[iel_base+iel_grp]: + midpoint_vertices = [] + midpoint_tuples = [] + vertex_elements = [] vertex_indices = grp.vertex_indices[iel_grp] for i in six.moves.range(len(vertex_indices)): for j in six.moves.range(i+1, len(vertex_indices)): + vertex_elements.append([]) min_index = min(vertex_indices[i], vertex_indices[j]) max_index = max(vertex_indices[i], vertex_indices[j]) cur_node = self.pair_map[(min_index, max_index)] if cur_node.midpoint is None: - cur_node.midpoint = nvertices - cur_node.left = TreeRayNode(cur_node.direction, cur_node.adjacent_elements) - cur_node.right = TreeRayNode(not cur_node.direction, cur_node.adjacent_elements) - left_index = cur_node.left.adjacent_elements.index(iel_base+iel_grp) - #right_index = cur_node - #cur_node.left.adjacent_elements[left_index] = - nvertices += 1 + cur_node.midpoint = vertices_index + import copy + cur_node.left = TreeRayNode(cur_node.direction, + copy.deepcopy(cur_node.adjacent_elements)) + cur_node.right = TreeRayNode(not cur_node.direction, + copy.deepcopy(cur_node.adjacent_elements)) + vertex_pair1 = (min_index, vertices_index) + vertex_pair2 = (max_index, vertices_index) + self.pair_map[vertex_pair1] = cur_node.left + self.pair_map[vertex_pair2] = cur_node.right + for el in cur_node.adjacent_elements: + if el != (iel_base+iel_grp): + vertex_elements[len(vertex_elements)-1].append(el) + new_hanging_vertex_element[vertices_index].append(el) + #compute midpoint coordinates + for k in six.moves.range(len(self.last_mesh.vertices)): + vertices[k, vertices_index] = \ + (self.last_mesh.vertices[k, vertex_indices[i]] + + self.last_mesh.vertices[k, vertex_indices[j]]) / 2.0 + midpoint_vertices.append(vertices_index) + vertices_index += 1 + else: + cur_midpoint = cur_node.midpoint + elements = cur_node.adjacent_elements + for el in elements: + if el != (iel_base + iel_grp) and el not in ( + vertex_elements[len(vertex_elements)-1]: + vertex_elements[len(vertex_elements)-1].append(el) + for el in new_hanging_vertex_element[cur_midpoint]: + if el != (iel_base + iel_grp) and el not in ( + vertex_elements[len(vertex_elements)-1]: + vertex_elements[len(vertex_elements)-1].append(el) + midpoint_vertices.append(cur_midpoint) + + + unique_vertex_pairs = [ + (i, j) for i in range(len(midpoint_vertices)) for j in range( + i+1, len(midpoint_vertices))] + midpoint_index = 0 + for i, j in unique_vertex_pairs: + min_index = min(midpoint_vertices[i], midpoint_vertices[j]) + max_index = max(midpoint_vertices[i], midpoint_vertices[j]) + vertex_pair = (min_index, max_index) + if vertex_pair not in self.pair_map: + continue + element_indices = [] + for k_index, k in enumerate(self.simplex_result): + ituple_index = self.simplex_node_tuples.index( + self.index_to_midpoint_tuple[i]) + jtuple_index = self.simplex_node_tuples.index( + self.index_to_midpoint_tuple[j]) + if ituple_index in k and jtuple_index in k: + element_indices.append(k_index) + + cur_node = self.pair_map[vertex_pair] + cur_node.adjacent_elements.remove(iel_base+iel_grp) + for k in element_indices: + if k == 0: + cur_node.adjacent_elements.append(iel_base+iel_grp) + else: + cur_node.adjacent_elements.append(iel_base+grp.nelements + + k - 1) + def print_rays(self, ind):