diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py index 05ce910e58ba4eac79c57a3194091bec1e5ea962..02a7e3bc18f7697244f80cdc76128fde9ae7104f 100644 --- a/meshmode/mesh/processing.py +++ b/meshmode/mesh/processing.py @@ -54,7 +54,7 @@ def partition_mesh(mesh, part_per_element, part_nr): *part_to_global* is a :class:`numpy.ndarray` mapping element numbers on *part_mesh* to ones in *mesh*. """ - assert len(part_per_element) == mesh.nelements + assert len(part_per_element) == mesh.nelements, "part_per_element must have shape (mesh.nelements,)" # Contains the indices of the elements requested. queried_elems = np.where(np.array(part_per_element) == part_nr)[0] @@ -106,23 +106,13 @@ def partition_mesh(mesh, part_per_element, part_nr): new_indices[group_nr][i, j] = np.where( required_indices == original_index)[0] - """ - print("mesh vertices: ", mesh.vertices) - print("mesh indices: ", mesh.groups[0].vertex_indices) - print("mesh nodes: ", mesh.groups[0].nodes) - print("queried_elems: ", queried_elems) - print("indices: ", new_indices[0]) - print("nodes: ", new_nodes[0]) - print("vertices: ", new_vertices) - """ - - from meshmode.mesh import MeshElementGroup, Mesh + from meshmode.mesh import SimplexElementGroup, Mesh new_mesh_groups = [] for group_nr in range(num_groups): mesh_group = mesh.groups[group_nr] new_mesh_groups.append( - MeshElementGroup(mesh_group.order, new_indices[group_nr], + SimplexElementGroup(mesh_group.order, new_indices[group_nr], new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes)) part_mesh = Mesh(new_vertices, new_mesh_groups) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 99a9284e291ffd661cb0b0b878e6b08ade531945..740b917b7483b6483b464eca9aa6e08dbf54f163 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -50,41 +50,46 @@ logger = logging.getLogger(__name__) # {{{ partition_mesh -#@pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"]) -#@pytest.mark.parametrize("npoints", [10, 1000]) -def test_partition_mesh(mesh_type, npoints): - from meshmode.mesh.generation import generate_torus - my_mesh = generate_torus(2, 1, n_outer=2, n_inner=2) +@pytest.mark.parameterize("mesh_type", ["torus", "box"]) +def test_partition_mesh(mesh_type): + if mesh_type == "torus": + from meshmode.mesh.generation import generate_torus + my_mesh = generate_torus(2, 1, n_outer=2, n_inner=2) + + part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0]) + + from meshmode.mesh.processing import partition_mesh + (part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 0) + assert part_mesh.nelements == 2 + + (part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1) + assert part_mesh.nelements == 4 + + (part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 2) + assert part_mesh.nelements == 2 + elif mesh_type == "box": + from meshmode.mesh.generation import generate_box_mesh + seg = np.linspace(0, 1, 10) + axis_coords = (seg, seg, seg) + mesh = generate_box_mesh(axis_coords) + + adjacency_list = np.zeros((mesh.nelements,), dtype=set) + for elem in range(mesh.nelements): + adjacency_list[elem] = set() + for n in range(mesh.nodal_adjacency.neighbors_starts[elem], mesh.nodal_adjacency.neighbors_starts[elem + 1]): + adjacency_list[elem].add(mesh.nodal_adjacency.neighbors[n]) + + from pymetis import part_graph + (_, part_per_element) = part_graph(3, adjacency=adjacency_list) + + from meshmode.mesh.processing import partition_mesh + (part_mesh0, _) = partition_mesh(mesh, np.array(part_per_element), 0) + (part_mesh1, _) = partition_mesh(mesh, np.array(part_per_element), 1) + (part_mesh2, _) = partition_mesh(mesh, np.array(part_per_element), 2) + + assert mesh.nelements == (part_mesh0.nelements + + part_mesh1.nelements + part_mesh2.nelements) - part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0]) - - from meshmode.mesh.processing import partition_mesh - (part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1) - - - """ - from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish - - if mesh_type == "cloverleaf": - mesh = make_curve_mesh(cloverleaf, np.linspace(0, 1, npoints), order=3) - elif mesh_type == "starfish": - mesh = make_curve_mesh(starfish, np.linspace(0, 1, npoints), order=3) - - #from meshmode.mesh.visualization import draw_2d_mesh - #draw_2d_mesh(mesh) - #import matplotlib.pyplot as pt - #pt.show() - - #TODO: Create an actuall adjacency list from the mesh. - adjacency_list = mesh.facial_adjacency_groups - print(adjacency_list) - - from pymetis import part_graph - part_per_element = np.array(part_graph(3, adjacency=adjacency_list)) - - from meshmode.mesh.processing import partition_mesh - (part_mesh, part_to_global) = partition_mesh(mesh, part_per_element, 0) - """ # }}}