From a786dae9b08801ebd4b93f6ba4f521cb60a85bb5 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 14 Dec 2016 15:23:53 -0600
Subject: [PATCH] Add single-group mode for merge_disjoint_meshes

---
 meshmode/mesh/processing.py | 56 +++++++++++++++++++++++++++++++------
 1 file changed, 47 insertions(+), 9 deletions(-)

diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py
index 06631f9d..ff74ac81 100644
--- a/meshmode/mesh/processing.py
+++ b/meshmode/mesh/processing.py
@@ -227,7 +227,7 @@ def find_bounding_box(mesh):
 
 # {{{ merging
 
-def merge_disjoint_meshes(meshes, skip_tests=False):
+def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False):
     if not meshes:
         raise ValueError("must pass at least one mesh")
 
@@ -262,18 +262,56 @@ def merge_disjoint_meshes(meshes, skip_tests=False):
 
     # {{{ assemble new groups list
 
-    new_groups = []
-
-    for mesh, vert_base in zip(meshes, vert_bases):
-        for group in mesh.groups:
-            new_vertex_indices = group.vertex_indices + vert_base
-            new_group = group.copy(vertex_indices=new_vertex_indices)
-            new_groups.append(new_group)
+    if single_group:
+        grp_cls = None
+        order = None
+        unit_nodes = None
+        nodal_adjacency = None
+
+        for mesh in meshes:
+            if mesh._nodal_adjacency is not None:
+                nodal_adjacency = False
+
+            for group in mesh.groups:
+                if grp_cls is None:
+                    grp_cls = type(group)
+                    order = group.order
+                    unit_nodes = group.unit_nodes
+                else:
+                    assert type(group) == grp_cls
+                    assert group.order == order
+                    assert np.array_equal(unit_nodes, group.unit_nodes)
+
+        vertex_indices = np.vstack([
+            group.vertex_indices + vert_base
+            for mesh, vert_base in zip(meshes, vert_bases)
+            for group in mesh.groups])
+        nodes = np.hstack([
+            group.nodes
+            for mesh in meshes
+            for group in mesh.groups])
+
+        new_groups = [
+                grp_cls(order, vertex_indices, nodes, unit_nodes=unit_nodes)]
+
+    else:
+        new_groups = []
+        nodal_adjacency = None
+
+        for mesh, vert_base in zip(meshes, vert_bases):
+            if mesh._nodal_adjacency is not None:
+                nodal_adjacency = False
+
+            for group in mesh.groups:
+                new_vertex_indices = group.vertex_indices + vert_base
+                new_group = group.copy(vertex_indices=new_vertex_indices)
+                new_groups.append(new_group)
 
     # }}}
 
     from meshmode.mesh import Mesh
-    return Mesh(vertices, new_groups, skip_tests=skip_tests)
+    return Mesh(vertices, new_groups, skip_tests=skip_tests,
+            nodal_adjacency=nodal_adjacency)
 
 # }}}
 
-- 
GitLab