From 90346bedeec0398a1952a511a6958b7e4f26b1cf Mon Sep 17 00:00:00 2001
From: Ellis <eshoag2@illinois.edu>
Date: Fri, 10 Feb 2017 20:01:56 -0600
Subject: [PATCH] partition mesh almost supports multiple element groups

---
 meshmode/mesh/processing.py | 64 +++++++++++++++++++++++--------------
 test/test_meshmode.py       | 17 ++++++++--
 2 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 27d3ef6e..7f520f44 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -31,6 +31,7 @@ import modepy as mp
 
 
 __doc__ = """
+.. autofunction:: partition_mesh
 .. autofunction:: find_volume_mesh_element_orientations
 .. autofunction:: perform_flips
 .. autofunction:: find_bounding_box
@@ -39,6 +40,7 @@ __doc__ = """
 .. autofunction:: affine_map
 """
 
+
 def partition_mesh(mesh, part_per_element, part_nr):
     """
     :arg mesh: A :class:`meshmode.mesh.Mesh` to be partitioned.
@@ -52,45 +54,59 @@ def partition_mesh(mesh, part_per_element, part_nr):
         *part_to_global* is a :class:`numpy.ndarray` mapping element
         numbers on *part_mesh* to ones in *mesh*.
     """
+    assert len(part_per_element) == mesh.nelements
+
+    queried_elems = np.where(np.array(part_per_element) == part_nr)[0]
+
+    num_groups = len(mesh.groups)
+    new_indices = []
+    new_nodes = []
+
+    # The set of vertex indices we need.
+    index_set = set()
 
-    queried_elems = np.where(part_per_element == part_nr)[0]
+    start_idx = 0
+    for group_nr in range(num_groups):
+        mesh_group = mesh.groups[group_nr]
+        #end_idx = start_idx + np.where(queried_elems >= mesh_group.nelements - start_idx)[0] - 1
+        end_idx = len(queried_elems)
 
-    '''
-    Here I assume that all elements are taken from the 0th group.
-    Will there be more groups? How do I handle this?
-    My guess is that len(part_per_element) will be equal to the sum of the
-    number of elements in each group. And that the first element of groups[1]
-    comes just after the last element in groups[0].
-    '''
-    mesh_group = mesh.groups[0]
+        new_indices.append(mesh_group.vertex_indices[queried_elems[start_idx:end_idx]])
 
-    new_vertex_indices = mesh_group.vertex_indices[queried_elems]
+        new_nodes.append(np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes)))
+        for i in range(mesh.ambient_dim):
+            for j in range(start_idx, end_idx):
+                new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, queried_elems[j], :]
+
+        index_set |= set(new_indices[group_nr].ravel())
+
+        start_idx = end_idx
 
     # A sorted np.array of vertex indices we need in our new mesh (without duplicates).
-    required_vertex_indices = np.array(list(set(new_vertex_indices.ravel())))
+    required_indices = np.array(list(index_set))
 
-    new_vertices = np.zeros((mesh.ambient_dim, len(required_vertex_indices)))
+    new_vertices = np.zeros((mesh.ambient_dim, len(required_indices)))
     for dim in range(mesh.ambient_dim):
-        new_vertices[dim] = mesh.vertices[dim][required_vertex_indices]
+        new_vertices[dim] = mesh.vertices[dim][required_indices]
 
-    # We need to update our indices to be in range [0, len(required_vertex_indices)].
-    for i in range(len(new_vertex_indices)):
-        for j in range(len(new_vertex_indices[0])):
-            new_vertex_indices[i, j] = np.where(required_vertex_indices == new_vertex_indices[i, j])[0]
-
-    new_nodes = np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes))
-    for i in range(mesh.ambient_dim):
-        for j in range(len(queried_elems)):
-            new_nodes[i, j, :] = mesh_group.nodes[i, queried_elems[j], :]
+    # We need to update our indices to be in range [0, len(mesh_group.nelements)].
+    for group_nr in range(num_groups):
+        for i in range(len(new_indices[group_nr])):
+            for j in range(len(new_indices[group_nr][0])):
+                new_indices[group_nr][i, j] = np.where(required_indices == new_indices[group_nr][i, j])[0]
 
     from meshmode.mesh import MeshElementGroup, Mesh
 
-    mesh_element_group = MeshElementGroup(mesh_group.order, new_vertex_indices, new_nodes, element_nr_base=mesh_group.element_nr_base, node_nr_base=mesh_group.node_nr_base, unit_nodes=mesh_group.unit_nodes, dim=mesh_group.dim)
+    new_mesh_groups = []
+    for group_nr in range(num_groups):
+        mesh_group = mesh.groups[group_nr]
+        new_mesh_groups.append(MeshElementGroup(mesh_group.order, new_indices[group_nr], new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes, dim=mesh_group.dim))
 
-    part_mesh = Mesh(new_vertices, [mesh_element_group])
+    part_mesh = Mesh(new_vertices, new_mesh_groups)
 
     return (part_mesh, queried_elems)
 
+
 # {{{ orientations
 
 def find_volume_mesh_element_group_orientation(vertices, grp):
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index eeda06b5..b5b2cde3 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -51,8 +51,19 @@ logger = logging.getLogger(__name__)
 
 @pytest.mark.parametrize("mesh_type", ["cloverleaf", "starfish"])
 @pytest.mark.parametrize("npoints", [10, 1000])
-def test_partition_mesh(mesh_type, npoints):
-    from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish
+def test_partition_mesh(mesh_type=None, npoints=None):
+    from meshmode.mesh.generation import generate_torus
+
+    my_mesh = generate_torus(2, 1, n_outer=2, n_inner=2)
+
+    part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0])
+
+    print(my_mesh.nelements, my_mesh.groups[0].nelements)
+
+    #(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1)
+    from meshmode.mesh.processing import partition_mesh
+    partition_mesh(my_mesh, part_per_element, 1)
+    """from meshmode.mesh.generation import make_curve_mesh, cloverleaf, starfish
 
     if mesh_type == "cloverleaf":
         mesh = make_curve_mesh(cloverleaf, np.linspace(0, 1, npoints), order=3)
@@ -67,7 +78,7 @@ def test_partition_mesh(mesh_type, npoints):
 
     from meshmode.mesh.processing import partition_mesh
     (part_mesh, part_to_global) = partition_mesh(mesh, part_per_element, 0)
-
+"""
 # }}}
 
 
-- 
GitLab