From 68c870eec9e622a41e4579c92bc287acfac6eaa5 Mon Sep 17 00:00:00 2001
From: ellis <eshoag2@illinois.edu>
Date: Tue, 14 Feb 2017 12:20:53 -0600
Subject: [PATCH] partition_mesh works with multiple MeshElementGroups

---
 meshmode/mesh/processing.py | 45 +++++++++++++++++++++++--------------
 test/test_meshmode.py       | 23 ++++++++++---------
 2 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 00b7a8bf..28d146b7 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -66,30 +66,39 @@ def partition_mesh(mesh, part_per_element, part_nr):
 
     # The set of vertex indices we need.
     index_set = set()
-
+    skip_groups = []
+    num_prev_elems = 0
     start_idx = 0
     for group_nr in range(num_groups):
         mesh_group = mesh.groups[group_nr]
 
         # Find the index of first element in the next group
         end_idx = len(queried_elems)
-        for idx in range(len(queried_elems)):
-            if queried_elems[idx] - start_idx > mesh_group.nelements:
+        for idx in range(start_idx, len(queried_elems)):
+            if queried_elems[idx] - num_prev_elems >= mesh_group.nelements:
                 end_idx = idx
                 break
 
-        elems = queried_elems[start_idx:end_idx]
+        if start_idx == end_idx:
+            skip_groups.append(group_nr)
+            new_indices.append(np.array([]))
+            new_nodes.append(np.array([]))
+            num_prev_elems += mesh_group.nelements
+            continue
+
+        elems = queried_elems[start_idx:end_idx] - num_prev_elems
         new_indices.append(mesh_group.vertex_indices[elems])
 
         new_nodes.append(
-            np.zeros((mesh.ambient_dim, len(queried_elems), mesh_group.nunit_nodes)))
+            np.zeros((mesh.ambient_dim, end_idx - start_idx, mesh_group.nunit_nodes)))
         for i in range(mesh.ambient_dim):
             for j in range(start_idx, end_idx):
-                elems = queried_elems[j]
-                new_nodes[group_nr][i, j, :] = mesh_group.nodes[i, elems, :]
+                elems = queried_elems[j] - num_prev_elems
+                new_nodes[group_nr][i, j - start_idx, :] = mesh_group.nodes[i, elems, :]
 
         index_set |= set(new_indices[group_nr].ravel())
 
+        num_prev_elems += mesh_group.nelements
         start_idx = end_idx
 
     # A sorted np.array of vertex indices we need (without duplicates).
@@ -99,22 +108,24 @@ def partition_mesh(mesh, part_per_element, part_nr):
     for dim in range(mesh.ambient_dim):
         new_vertices[dim] = mesh.vertices[dim][required_indices]
 
-    # Our indices need to be in range [0, len(mesh_group.nelements)].
+    # Our indices need to be in range [0, len(mesh.nelements)].
     for group_nr in range(num_groups):
-        for i in range(len(new_indices[group_nr])):
-            for j in range(len(new_indices[group_nr][0])):
-                original_index = new_indices[group_nr][i, j]
-                new_indices[group_nr][i, j] = np.where(
-                    required_indices == original_index)[0]
+        if not group_nr in skip_groups:
+            for i in range(len(new_indices[group_nr])):
+                for j in range(len(new_indices[group_nr][0])):
+                    original_index = new_indices[group_nr][i, j]
+                    new_indices[group_nr][i, j] = np.where(
+                        required_indices == original_index)[0]
 
     from meshmode.mesh import SimplexElementGroup, Mesh
 
     new_mesh_groups = []
     for group_nr in range(num_groups):
-        mesh_group = mesh.groups[group_nr]
-        new_mesh_groups.append(
-            SimplexElementGroup(mesh_group.order, new_indices[group_nr],
-                new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
+        if not group_nr in skip_groups:
+            mesh_group = mesh.groups[group_nr]
+            new_mesh_groups.append(
+                SimplexElementGroup(mesh_group.order, new_indices[group_nr],
+                    new_nodes[group_nr], unit_nodes=mesh_group.unit_nodes))
 
     part_mesh = Mesh(new_vertices, new_mesh_groups)
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 4cdffd85..30b5e001 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -50,7 +50,7 @@ logger = logging.getLogger(__name__)
 
 # {{{ partition_mesh
 
-@pytest.mark.parametrize("mesh_type", ["torus", "box"])
+@pytest.mark.parametrize("mesh_type", ["torus", "boxes"])
 def test_partition_mesh(mesh_type):
     if mesh_type == "torus":
         from meshmode.mesh.generation import generate_torus
@@ -67,11 +67,13 @@ def test_partition_mesh(mesh_type):
         assert part_mesh1.nelements == 4
         assert part_mesh2.nelements == 2
 
-    elif mesh_type == "box":
-        from meshmode.mesh.generation import generate_box_mesh
-        seg = np.linspace(0, 1, 10)
-        axis_coords = (seg, seg, seg)
-        mesh = generate_box_mesh(axis_coords)
+    elif mesh_type == "boxes":
+        from meshmode.mesh.generation import generate_regular_rect_mesh
+        mesh1 = generate_regular_rect_mesh(a=(0, 0, 0), b=(1, 1, 1), n=(5, 5, 5))
+        mesh2 = generate_regular_rect_mesh(a=(2, 2, 2), b=(3, 3, 3), n=(5, 5, 5))
+
+        from meshmode.mesh.processing import merge_disjoint_meshes
+        mesh = merge_disjoint_meshes([mesh1, mesh2])
 
         adjacency_list = np.zeros((mesh.nelements,), dtype=set)
         for elem in range(mesh.nelements):
@@ -81,12 +83,13 @@ def test_partition_mesh(mesh_type):
                 adjacency_list[elem].add(mesh.nodal_adjacency.neighbors[n])
 
         from pymetis import part_graph
-        (_, part_per_element) = part_graph(3, adjacency=adjacency_list)
+        (_, p) = part_graph(3, adjacency=adjacency_list)
+        part_per_element = np.array(p)
 
         from meshmode.mesh.processing import partition_mesh
-        (part_mesh0, _) = partition_mesh(mesh, np.array(part_per_element), 0)
-        (part_mesh1, _) = partition_mesh(mesh, np.array(part_per_element), 1)
-        (part_mesh2, _) = partition_mesh(mesh, np.array(part_per_element), 2)
+        (part_mesh0, _) = partition_mesh(mesh, part_per_element, 0)
+        (part_mesh1, _) = partition_mesh(mesh, part_per_element, 1)
+        (part_mesh2, _) = partition_mesh(mesh, part_per_element, 2)
 
         assert mesh.nelements == (part_mesh0.nelements
             + part_mesh1.nelements + part_mesh2.nelements)
-- 
GitLab