From 488b613bc13c3ee0b3ecaae3e01c2170ca005f20 Mon Sep 17 00:00:00 2001
From: Matt Smith <mjsmith6@illinois.edu>
Date: Sat, 9 Apr 2022 20:30:58 -0500
Subject: [PATCH] Optionally return volume-tag-to-element map from `read_gmsh` 
 (#320)

* add tag list to testmesh.msh

* add option to return volume-tag-to-element map in read_gmsh

* also add return_volume_to_elements_map argument to generate_gmsh

* fix pylint errors?

or maybe just confuse it sufficiently...
---
 meshmode/mesh/io.py        | 56 ++++++++++++++++++++++++++++++--------
 test/test_mesh.py          | 20 ++++++++++++++
 test/testmesh.msh          |  8 ++++--
 test/testmesh_multivol.msh | 20 ++++++++++++++
 4 files changed, 90 insertions(+), 14 deletions(-)
 create mode 100644 test/testmesh_multivol.msh

diff --git a/meshmode/mesh/io.py b/meshmode/mesh/io.py
index 4b06cd6a..1f03dc0e 100644
--- a/meshmode/mesh/io.py
+++ b/meshmode/mesh/io.py
@@ -51,7 +51,6 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
         # Use data fields similar to meshpy.triangle.MeshInfo and
         # meshpy.tet.MeshInfo
         self.points = None
-        self.elements = None
         self.element_vertices = None
         self.element_nodes = None
         self.element_types = None
@@ -117,7 +116,7 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
     def finalize_tags(self):
         pass
 
-    def get_mesh(self):
+    def get_mesh(self, return_tag_to_elements_map=False):
         el_type_hist = {}
         for el_type in self.element_types:
             el_type_hist[el_type] = el_type_hist.get(el_type, 0) + 1
@@ -141,8 +140,10 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
                     vertex_gmsh_index_to_mine[gmsh_vertex_nr] = \
                             len(vertex_gmsh_index_to_mine)
             if self.tags:
-                el_tag_indexes = [self.gmsh_tag_index_to_mine[t] for t in
-                                  self.element_markers[element]]
+                el_markers = self.element_markers[element]
+                el_tag_indexes = (
+                    [self.gmsh_tag_index_to_mine[t] for t in el_markers]
+                    if el_markers is not None else [])
                 # record tags of boundary dimension
                 el_tags = [self.tags[i][0] for i in el_tag_indexes if
                            self.tags[i][1] == mesh_bulk_dim - 1]
@@ -170,6 +171,10 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
 
         bulk_el_types = set()
 
+        group_base_elem_nr = 0
+
+        tag_to_elements = {}
+
         for group_el_type, ngroup_elements in el_type_hist.items():
             if group_el_type.dimensions != mesh_bulk_dim:
                 continue
@@ -185,8 +190,9 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
                     np.int32)
             i = 0
 
-            for el_vertices, el_nodes, el_type in zip(
-                    self.element_vertices, self.element_nodes, self.element_types):
+            for el_vertices, el_nodes, el_type, el_markers in zip(
+                    self.element_vertices, self.element_nodes, self.element_types,
+                    self.element_markers):
                 if el_type is not group_el_type:
                     continue
 
@@ -195,6 +201,14 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
                         vertex_gmsh_index_to_mine[v_nr] for v_nr in el_vertices
                         ]
 
+                if el_markers is not None:
+                    for t in el_markers:
+                        tag = self.tags[self.gmsh_tag_index_to_mine[t]][0]
+                        if tag not in tag_to_elements:
+                            tag_to_elements[tag] = [group_base_elem_nr + i]
+                        else:
+                            tag_to_elements[tag].append(group_base_elem_nr + i)
+
                 i += 1
 
             import modepy as mp
@@ -238,6 +252,11 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
 
             groups.append(group)
 
+            group_base_elem_nr += group.nelements
+
+        for tag in tag_to_elements.keys():
+            tag_to_elements[tag] = np.array(tag_to_elements[tag], dtype=np.int32)
+
         # FIXME: This is heuristic.
         if len(bulk_el_types) == 1:
             is_conforming = True
@@ -251,18 +270,22 @@ class GmshMeshReceiver(GmshMeshReceiverBase):
             facial_adjacency_groups = _compute_facial_adjacency_from_vertices(
                     groups, np.int32, np.int8, face_vertex_indices_to_tags)
 
-        return Mesh(
+        mesh = Mesh(
                 vertices, groups,
                 is_conforming=is_conforming,
                 facial_adjacency_groups=facial_adjacency_groups,
                 **self.mesh_construction_kwargs)
 
+        return (mesh, tag_to_elements) if return_tag_to_elements_map else mesh
+
 # }}}
 
 
 # {{{ gmsh
 
-def read_gmsh(filename, force_ambient_dim=None, mesh_construction_kwargs=None):
+def read_gmsh(
+        filename, force_ambient_dim=None, mesh_construction_kwargs=None,
+        return_tag_to_elements_map=False):
     """Read a gmsh mesh file from *filename* and return a
     :class:`meshmode.mesh.Mesh`.
 
@@ -270,18 +293,22 @@ def read_gmsh(filename, force_ambient_dim=None, mesh_construction_kwargs=None):
         this many dimensions.
     :arg mesh_construction_kwargs: *None* or a dictionary of keyword
         arguments passed to the :class:`meshmode.mesh.Mesh` constructor.
+    :arg return_tag_to_elements_map: If *True*, return in addition to the mesh
+        a :class:`dict` that maps each volume tag in the gmsh file to a
+        :class:`numpy.ndarray` containing meshwide indices of the elements that
+        belong to that volume.
     """
     from gmsh_interop.reader import read_gmsh
     recv = GmshMeshReceiver(mesh_construction_kwargs=mesh_construction_kwargs)
     read_gmsh(recv, filename, force_dimension=force_ambient_dim)
 
-    return recv.get_mesh()
+    return recv.get_mesh(return_tag_to_elements_map=return_tag_to_elements_map)
 
 
 def generate_gmsh(source, dimensions=None, order=None, other_options=None,
         extension="geo", gmsh_executable="gmsh", force_ambient_dim=None,
         output_file_name="output.msh", mesh_construction_kwargs=None,
-        target_unit=None):
+        target_unit=None, return_tag_to_elements_map=False):
     """Run :command:`gmsh` on the input given by *source*, and return a
     :class:`meshmode.mesh.Mesh` based on the result.
 
@@ -316,9 +343,14 @@ def generate_gmsh(source, dimensions=None, order=None, other_options=None,
         parse_gmsh(recv, runner.output_file,
                 force_dimension=force_ambient_dim)
 
-    mesh = recv.get_mesh()
+    result = recv.get_mesh(return_tag_to_elements_map=return_tag_to_elements_map)
 
     if force_ambient_dim is None:
+        if return_tag_to_elements_map:
+            mesh = result[0]
+        else:
+            mesh = result
+
         AXIS_NAMES = "xyz"  # noqa
 
         dim = mesh.vertices.shape[0]
@@ -331,7 +363,7 @@ def generate_gmsh(source, dimensions=None, order=None, other_options=None,
                             AXIS_NAMES[idim], idim))
                 break
 
-    return mesh
+    return result
 
 # }}}
 
diff --git a/test/test_mesh.py b/test/test_mesh.py
index cbcde031..161bf900 100644
--- a/test/test_mesh.py
+++ b/test/test_mesh.py
@@ -735,6 +735,26 @@ def test_boundary_tags():
 # }}}
 
 
+# {{{ test volume tags
+
+def test_volume_tags():
+    from meshmode.mesh.io import read_gmsh
+    mesh, tag_to_elements_map = read_gmsh(
+        "testmesh_multivol.msh", return_tag_to_elements_map=True)
+
+    assert len(tag_to_elements_map) == 2
+
+    assert "Vol1" in tag_to_elements_map
+    assert "Vol2" in tag_to_elements_map
+
+    assert isinstance(tag_to_elements_map["Vol1"], np.ndarray)
+
+    assert np.all(tag_to_elements_map["Vol1"] == np.array([0]))
+    assert np.all(tag_to_elements_map["Vol2"] == np.array([1]))
+
+# }}}
+
+
 # {{{ test custom boundary tags on box mesh
 
 @pytest.mark.parametrize(("dim", "nelem", "mesh_type"), [
diff --git a/test/testmesh.msh b/test/testmesh.msh
index 07de5a57..14f5be95 100644
--- a/test/testmesh.msh
+++ b/test/testmesh.msh
@@ -1,6 +1,10 @@
 $MeshFormat
 2.2 0 8
 $EndMeshFormat
+$PhysicalNames
+1
+2 1 "Volume"
+$EndPhysicalNames
 $Nodes
 4
 1 0 0 0
@@ -10,6 +14,6 @@ $Nodes
 $EndNodes
 $Elements
 2
-1 2 2 3 1 1 2 3
-2 2 2 3 1 2 3 4
+1 2 2 1 1 1 2 3
+2 2 2 1 1 2 3 4
 $EndElements
diff --git a/test/testmesh_multivol.msh b/test/testmesh_multivol.msh
new file mode 100644
index 00000000..9b5e7ff3
--- /dev/null
+++ b/test/testmesh_multivol.msh
@@ -0,0 +1,20 @@
+$MeshFormat
+2.2 0 8
+$EndMeshFormat
+$PhysicalNames
+2
+2 1 "Vol1"
+2 2 "Vol2"
+$EndPhysicalNames
+$Nodes
+4
+1 0 0 0
+2 1 0 0
+3 0 1 0
+4 1 1 0
+$EndNodes
+$Elements
+2
+1 2 2 1 1 1 2 3
+2 2 2 2 1 2 3 4
+$EndElements
-- 
GitLab