Skip to content
Snippets Groups Projects
Commit 28691b12 authored by Ellis Hoag's avatar Ellis Hoag
Browse files

partition_mesh handles multiple MeshElementGroups

parent 90346bed
No related branches found
No related tags found
1 merge request!9Master
......@@ -56,6 +56,7 @@ def partition_mesh(mesh, part_per_element, part_nr):
"""
assert len(part_per_element) == mesh.nelements
# Contains the indices of the elements requested.
queried_elems = np.where(np.array(part_per_element) == part_nr)[0]
num_groups = len(mesh.groups)
......@@ -68,39 +69,61 @@ def partition_mesh(mesh, part_per_element, part_nr):
start_idx = 0
for group_nr in range(num_groups):
mesh_group = mesh.groups[group_nr]
#end_idx = start_idx + np.where(queried_elems >= mesh_group.nelements - start_idx)[0] - 1
# Find the index of first element in the next group
end_idx = len(queried_elems)
for idx in range(len(queried_elems)):
if queried_elems[idx] - start_idx > mesh_group.nelements:
end_idx = idx
break
new_indices.append(mesh_group.vertex_indices[queried_elems[start_idx:end_idx]])
elems = queried_elems[start_idx:end_idx]
new_indices.append(mesh_group.vertex_indices[elems])
new_nodes.append(np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes)))
new_nodes.append(
np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes)))
for i in range(mesh.ambient_dim):
for j in range(start_idx, end_idx):
new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, queried_elems[j], :]
elems = queried_elems[j]
new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, elems, :]
index_set |= set(new_indices[group_nr].ravel())
start_idx = end_idx
# A sorted np.array of vertex indices we need in our new mesh (without duplicates).
# A sorted np.array of vertex indices we need (without duplicates).
required_indices = np.array(list(index_set))
new_vertices = np.zeros((mesh.ambient_dim, len(required_indices)))
for dim in range(mesh.ambient_dim):
new_vertices[dim] = mesh.vertices[dim][required_indices]
# We need to update our indices to be in range [0, len(mesh_group.nelements)].
# Our indices need to be in range [0, len(mesh_group.nelements)].
for group_nr in range(num_groups):
for i in range(len(new_indices[group_nr])):
for j in range(len(new_indices[group_nr][0])):
new_indices[group_nr][i, j] = np.where(required_indices == new_indices[group_nr][i, j])[0]
original_index = new_indices[group_nr][i, j]
new_indices[group_nr][i, j] = np.where(
required_indices == original_index)[0]
"""
print("mesh vertices: ", mesh.vertices)
print("mesh indices: ", mesh.groups[0].vertex_indices)
print("mesh nodes: ", mesh.groups[0].nodes)
print("queried_elems: ", queried_elems)
print("indices: ", new_indices[0])
print("nodes: ", new_nodes[0])
print("vertices: ", new_vertices)
"""
from meshmode.mesh import MeshElementGroup, Mesh
new_mesh_groups = []
for group_nr in range(num_groups):
mesh_group = mesh.groups[group_nr]
new_mesh_groups.append(MeshElementGroup(mesh_group.order, new_indices[group_nr], new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes, dim=mesh_group.dim))
new_mesh_groups.append(
MeshElementGroup(mesh_group.order, new_indices[group_nr],
new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
part_mesh = Mesh(new_vertices, new_mesh_groups)
......
......@@ -47,38 +47,44 @@ import pytest
import logging
logger = logging.getLogger(__name__)
# {{{ partition_mesh
@pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"])
@pytest.mark.parametrize("npoints", [10, 1000])
def test_partition_mesh(mesh_type=None, npoints=None):
#@pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"])
#@pytest.mark.parametrize("npoints", [10, 1000])
def test_partition_mesh(mesh_type, npoints):
from meshmode.mesh.generation import generate_torus
my_mesh = generate_torus(2, 1, n_outer=2, n_inner=2)
part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0])
print(my_mesh.nelements, my_mesh.groups[0].nelements)
#(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1)
from meshmode.mesh.processing import partition_mesh
partition_mesh(my_mesh, part_per_element, 1)
"""from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish
(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1)
"""
from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish
if mesh_type == "cloverleaf":
mesh = make_curve_mesh(cloverleaf, np.linspace(0, 1, npoints), order=3)
elif mesh_type == "starfish":
mesh = make_curve_mesh(starfish, np.linspace(0, 1, npoints), order=3)
#from meshmode.mesh.visualization import draw_2d_mesh
#draw_2d_mesh(mesh)
#import matplotlib.pyplot as pt
#pt.show()
#TODO: Create an actuall adjacency list from the mesh.
adjacency_list = mesh.nodal_adjacency
adjacency_list = mesh.facial_adjacency_groups
print(adjacency_list)
from pymetis import part_graph
part_per_element = np.array(part_graph(3, adjacency=adjacency_list))
from meshmode.mesh.processing import partition_mesh
(part_mesh, part_to_global) = partition_mesh(mesh, part_per_element, 0)
"""
"""
# }}}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment