diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py index 7f520f44bf583c7a57ef3b8b5383884c144c4e63..05ce910e58ba4eac79c57a3194091bec1e5ea962 100644 --- a/meshmode/mesh/processing.py +++ b/meshmode/mesh/processing.py @@ -56,6 +56,7 @@ def partition_mesh(mesh, part_per_element, part_nr): """ assert len(part_per_element) == mesh.nelements + # Contains the indices of the elements requested. queried_elems = np.where(np.array(part_per_element) == part_nr)[0] num_groups = len(mesh.groups) @@ -68,39 +69,61 @@ def partition_mesh(mesh, part_per_element, part_nr): start_idx = 0 for group_nr in range(num_groups): mesh_group = mesh.groups[group_nr] - #end_idx = start_idx + np.where(queried_elems >= mesh_group.nelements - start_idx)[0] - 1 + + # Find the index of first element in the next group end_idx = len(queried_elems) + for idx in range(len(queried_elems)): + if queried_elems[idx] - start_idx > mesh_group.nelements: + end_idx = idx + break - new_indices.append(mesh_group.vertex_indices[queried_elems[start_idx:end_idx]]) + elems = queried_elems[start_idx:end_idx] + new_indices.append(mesh_group.vertex_indices[elems]) - new_nodes.append(np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes))) + new_nodes.append( + np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes))) for i in range(mesh.ambient_dim): for j in range(start_idx, end_idx): - new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, queried_elems[j], :] + elems = queried_elems[j] + new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, elems, :] index_set |= set(new_indices[group_nr].ravel()) start_idx = end_idx - # A sorted np.array of vertex indices we need in our new mesh (without duplicates). + # A sorted np.array of vertex indices we need (without duplicates). required_indices = np.array(list(index_set)) new_vertices = np.zeros((mesh.ambient_dim, len(required_indices))) for dim in range(mesh.ambient_dim): new_vertices[dim] = mesh.vertices[dim][required_indices] - # We need to update our indices to be in range [0, len(mesh_group.nelements)]. + # Our indices need to be in range [0, len(mesh_group.nelements)]. for group_nr in range(num_groups): for i in range(len(new_indices[group_nr])): for j in range(len(new_indices[group_nr][0])): - new_indices[group_nr][i, j] = np.where(required_indices == new_indices[group_nr][i, j])[0] + original_index = new_indices[group_nr][i, j] + new_indices[group_nr][i, j] = np.where( + required_indices == original_index)[0] + + """ + print("mesh vertices: ", mesh.vertices) + print("mesh indices: ", mesh.groups[0].vertex_indices) + print("mesh nodes: ", mesh.groups[0].nodes) + print("queried_elems: ", queried_elems) + print("indices: ", new_indices[0]) + print("nodes: ", new_nodes[0]) + print("vertices: ", new_vertices) + """ from meshmode.mesh import MeshElementGroup, Mesh new_mesh_groups = [] for group_nr in range(num_groups): mesh_group = mesh.groups[group_nr] - new_mesh_groups.append(MeshElementGroup(mesh_group.order, new_indices[group_nr], new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes, dim=mesh_group.dim)) + new_mesh_groups.append( + MeshElementGroup(mesh_group.order, new_indices[group_nr], + new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes)) part_mesh = Mesh(new_vertices, new_mesh_groups) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index b5b2cde37195eee8e8398ea4de808a76224f22ae..99a9284e291ffd661cb0b0b878e6b08ade531945 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -47,38 +47,44 @@ import pytest import logging logger = logging.getLogger(__name__) + # {{{ partition_mesh -@pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"]) -@pytest.mark.parametrize("npoints", [10, 1000]) -def test_partition_mesh(mesh_type=None, npoints=None): +#@pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"]) +#@pytest.mark.parametrize("npoints", [10, 1000]) +def test_partition_mesh(mesh_type, npoints): from meshmode.mesh.generation import generate_torus - my_mesh = generate_torus(2, 1, n_outer=2, n_inner=2) part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0]) - print(my_mesh.nelements, my_mesh.groups[0].nelements) - - #(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1) from meshmode.mesh.processing import partition_mesh - partition_mesh(my_mesh, part_per_element, 1) - """from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish + (part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1) + + + """ + from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish if mesh_type == "cloverleaf": mesh = make_curve_mesh(cloverleaf, np.linspace(0, 1, npoints), order=3) elif mesh_type == "starfish": mesh = make_curve_mesh(starfish, np.linspace(0, 1, npoints), order=3) + #from meshmode.mesh.visualization import draw_2d_mesh + #draw_2d_mesh(mesh) + #import matplotlib.pyplot as pt + #pt.show() + #TODO: Create an actuall adjacency list from the mesh. - adjacency_list = mesh.nodal_adjacency + adjacency_list = mesh.facial_adjacency_groups + print(adjacency_list) from pymetis import part_graph part_per_element = np.array(part_graph(3, adjacency=adjacency_list)) from meshmode.mesh.processing import partition_mesh (part_mesh, part_to_global) = partition_mesh(mesh, part_per_element, 0) -""" + """ # }}}