diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py index ff74ac81abdc6e414bd5296b27e8545a047acd25..27d3ef6e39b9e8fc912ca0109332706bb215b66e 100644 --- a/meshmode/mesh/processing.py +++ b/meshmode/mesh/processing.py @@ -39,6 +39,57 @@ __doc__ = """ .. autofunction:: affine_map """ +def partition_mesh(mesh, part_per_element, part_nr): + """ + :arg mesh: A :class:`meshmode.mesh.Mesh` to be partitioned. + :arg part_per_element: A :class:`numpy.ndarray` containing one + integer per element of *mesh* indicating which part of the + partitioned mesh the element is to become a part of. + :arg part_nr: The part number of the mesh to return. + + :returns: A tuple ``(part_mesh, part_to_global)``, where *part_mesh* + is a :class:`meshmode.mesh.Mesh` that is a partition of mesh, and + *part_to_global* is a :class:`numpy.ndarray` mapping element + numbers on *part_mesh* to ones in *mesh*. + """ + + queried_elems = np.where(part_per_element == part_nr)[0] + + ''' + Here I assume that all elements are taken from the 0th group. + Will there be more groups? How do I handle this? + My guess is that len(part_per_element) will be equal to the sum of the + number of elements in each group. And that the first element of groups[1] + comes just after the last element in groups[0]. + ''' + mesh_group = mesh.groups[0] + + new_vertex_indices = mesh_group.vertex_indices[queried_elems] + + # A sorted np.array of vertex indices we need in our new mesh (without duplicates). + required_vertex_indices = np.array(list(set(new_vertex_indices.ravel()))) + + new_vertices = np.zeros((mesh.ambient_dim, len(required_vertex_indices))) + for dim in range(mesh.ambient_dim): + new_vertices[dim] = mesh.vertices[dim][required_vertex_indices] + + # We need to update our indices to be in range [0, len(required_vertex_indices)]. + for i in range(len(new_vertex_indices)): + for j in range(len(new_vertex_indices[0])): + new_vertex_indices[i, j] = np.where(required_vertex_indices == new_vertex_indices[i, j])[0] + + new_nodes = np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes)) + for i in range(mesh.ambient_dim): + for j in range(len(queried_elems)): + new_nodes[i, j, :] = mesh_group.nodes[i, queried_elems[j], :] + + from meshmode.mesh import MeshElementGroup, Mesh + + mesh_element_group = MeshElementGroup(mesh_group.order, new_vertex_indices, new_nodes, element_nr_base=mesh_group.element_nr_base, node_nr_base=mesh_group.node_nr_base, unit_nodes=mesh_group.unit_nodes, dim=mesh_group.dim) + + part_mesh = Mesh(new_vertices, [mesh_element_group]) + + return (part_mesh, queried_elems) # {{{ orientations