diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py index 97b9bbda9c99fd0dc1f537e6cbbbaf043bc6140c..81d8526b977d7883ab75ba89da71a64b9558e9f8 100644 --- a/meshmode/mesh/processing.py +++ b/meshmode/mesh/processing.py @@ -569,6 +569,51 @@ def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False): # }}} +# {{{ split meshes + + +def split_mesh_groups(mesh, element_flags): + """Split all the groups in *mesh* in two according to the values of + *element_flags*. + + :arg element_flags: a :class:`numpy.ndarray` with + :attr:`~meshmode.mesh.Mesh.nelements` entries + indicating by their *Boolean* value how the elements in a group + are to be split. + + :returns: a :class:`~meshmode.mesh.Mesh` where each group has been split + according to flags in *element_flags*. + """ + assert element_flags.shape == (mesh.nelements,) + + new_groups = [] + for grp in mesh.groups: + mask = element_flags[ + grp.element_nr_base:grp.element_nr_base + grp.nelements] + + vertex_indices = grp.vertex_indices[mask, :].copy() + if vertex_indices.size > 0: + new_groups.append(grp.copy( + vertex_indices=vertex_indices, + nodes=grp.nodes[:, mask, :].copy() + )) + + vertex_indices = grp.vertex_indices[~mask, :].copy() + if vertex_indices.size > 0: + new_groups.append(grp.copy( + vertex_indices=vertex_indices, + nodes=grp.nodes[:, ~mask, :].copy() + )) + + from meshmode.mesh import Mesh + return Mesh( + vertices=mesh.vertices, + groups=new_groups, + is_conforming=mesh.is_conforming) + +# }}} + + # {{{ map def map_mesh(mesh, f): # noqa diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 06b7bcfab7373837d45b61fcc7d929e1746b4333..b6a564abb7647cc20db52054b0a567f19a341f0a 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1304,11 +1304,19 @@ def test_mesh_multiple_groups(ctx_factory, ambient_dim, visualize=False): mesh = generate_regular_rect_mesh( a=(-0.5,)*ambient_dim, b=(0.5,)*ambient_dim, n=(8,)*ambient_dim, order=order) - mesh = split_mesh(mesh, axis=0, cutoff=0.0) + assert len(mesh.groups) == 1 + + from meshmode.mesh.processing import split_mesh_groups + element_flags = np.any( + mesh.vertices[0, mesh.groups[0].vertex_indices] < 0.0, + axis=1) + mesh = split_mesh_groups(mesh, element_flags) + + assert len(mesh.groups) == 2 assert mesh.facial_adjacency_groups assert mesh.nodal_adjacency - if visualize: + if visualize and ambient_dim == 2: from meshmode.mesh.visualization import draw_2d_mesh draw_2d_mesh(mesh, draw_element_numbers=True, draw_face_numbers=True, set_bounding_box=True)