From 5d6dd963a44747e1985eeb7396e1fe6aa31ff445 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Thu, 28 May 2020 12:00:24 -0500
Subject: [PATCH] add a test with a split mesh with multiple groups

---
 test/test_meshmode.py | 80 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 80 insertions(+)

diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 77a25cd5..d8cede25 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -1272,6 +1272,86 @@ def test_is_affine_group_check(mesh_name):
     assert all(grp.is_affine for grp in mesh.groups) == is_affine
 
 
+def split_mesh(mesh, axis, cutoff):
+    groups = []
+    for grp in mesh.groups:
+        mask = np.any(mesh.vertices[axis, grp.vertex_indices] < cutoff, axis=1)
+
+        groups.append(grp.copy(
+                vertex_indices=grp.vertex_indices[mask, :].copy(),
+                nodes=grp.nodes[:, mask, :].copy())
+                )
+        groups.append(grp.copy(
+                vertex_indices=grp.vertex_indices[~mask, :].copy(),
+                nodes=grp.nodes[:, ~mask, :].copy())
+                )
+
+    from meshmode.mesh import Mesh
+    return Mesh(
+            vertices=mesh.vertices,
+            groups=groups,
+            is_conforming=mesh.is_conforming)
+
+
+@pytest.mark.parametrize("ambient_dim", [1, 2, 3])
+def test_mesh_multiple_groups(ctx_factory, ambient_dim, visualize=False):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    order = 4
+
+    from meshmode.mesh.generation import generate_regular_rect_mesh
+    mesh = generate_regular_rect_mesh(
+            a=(-0.5,)*ambient_dim, b=(0.5,)*ambient_dim,
+            n=(8,)*ambient_dim, order=order)
+    mesh = split_mesh(mesh, axis=0, cutoff=0.0)
+    assert mesh.facial_adjacency_groups
+    assert mesh.nodal_adjacency
+
+    from meshmode.discretization import Discretization
+    from meshmode.discretization.poly_element import \
+            PolynomialWarpAndBlendGroupFactory as GroupFactory
+    discr = Discretization(ctx, mesh, GroupFactory(order))
+
+    if visualize:
+        group_id = discr.empty(queue, dtype=np.int)
+        for igrp, grp in enumerate(discr.groups):
+            group_id_view = grp.view(group_id)
+            group_id_view.fill(igrp)
+
+        from meshmode.discretization.visualization import make_visualizer
+        vis = make_visualizer(queue, discr, vis_order=order)
+        vis.write_vtk_file("test_mesh_multiple_groups.vtu", [
+            ("group_id", group_id)
+            ], overwrite=True)
+
+    # check face restrictions
+    from meshmode.discretization.connection import (
+            make_face_restriction,
+            make_face_to_all_faces_embedding,
+            make_opposite_face_connection,
+            check_connection)
+    for boundary_tag in [BTAG_ALL, FACE_RESTR_INTERIOR, FACE_RESTR_ALL]:
+        conn = make_face_restriction(discr, GroupFactory(order),
+                boundary_tag=boundary_tag,
+                per_face_groups=False)
+        check_connection(conn)
+
+        bdry_f = conn.to_discr.empty(queue)
+        bdry_f.fill(1.0)
+
+        if boundary_tag == FACE_RESTR_INTERIOR:
+            opposite = make_opposite_face_connection(conn)
+            check_connection(opposite)
+
+            op_bdry_f = opposite(queue, bdry_f)
+            assert abs(cl.array.sum(bdry_f - op_bdry_f).get(queue)) < 1.0e-14
+
+        if boundary_tag == FACE_RESTR_ALL:
+            embedding = make_face_to_all_faces_embedding(conn, conn.to_discr)
+            check_connection(embedding)
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab