From 67a2acf47400fca8905613ffbe238273c642facd Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 22 Feb 2017 15:39:24 -0600
Subject: [PATCH] Visualization for tensor product elements

---
 meshmode/discretization/visualization.py | 49 +++++++++++++++++++-----
 test/test_meshmode.py                    | 39 +++++++++++++------
 2 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py
index bfd8bc5e..660e78d3 100644
--- a/meshmode/discretization/visualization.py
+++ b/meshmode/discretization/visualization.py
@@ -124,8 +124,9 @@ class Visualizer(object):
         """
         # Assume that we're using modepy's default node ordering.
 
-        from pytools import generate_nonnegative_integer_tuples_summing_to_at_most \
-                as gnitstam, single_valued
+        from pytools import (
+                generate_nonnegative_integer_tuples_summing_to_at_most as gnitstam,
+                generate_nonnegative_integer_tuples_below as gnitb)
         from meshmode.mesh import TensorProductElementGroup, SimplexElementGroup
 
         result = []
@@ -138,9 +139,7 @@ class Visualizer(object):
 
         for group in self.vis_discr.groups:
             if isinstance(group.mesh_el_group, SimplexElementGroup):
-                vis_order = single_valued(
-                        group.order for group in self.vis_discr.groups)
-                node_tuples = list(gnitstam(vis_order, self.vis_discr.dim))
+                node_tuples = list(gnitstam(group.order, group.dim))
 
                 from modepy.tools import submesh
                 el_connectivity = np.array(
@@ -153,7 +152,33 @@ class Visualizer(object):
                         }[group.dim]
 
             elif isinstance(group.mesh_el_group, TensorProductElementGroup):
-                raise NotImplementedError()
+                node_tuples = list(gnitb(group.order+1, group.dim))
+                node_tuple_to_index = dict(
+                        (nt, i) for i, nt in enumerate(node_tuples))
+
+                def add_tuple(a, b):
+                    return tuple(ai+bi for ai, bi in zip(a, b))
+
+                el_offsets = {
+                        1: [(0,), (1,)],
+                        2: [(0, 0), (1, 0), (1, 1), (0, 1)],
+                        3: [
+                            (0, 0, 0),
+                            (1, 0, 0),
+                            (1, 1, 0),
+                            (0, 1, 0),
+                            (0, 0, 1),
+                            (1, 0, 1),
+                            (1, 1, 1),
+                            (0, 1, 1),
+                            ]
+                        }[group.dim]
+
+                el_connectivity = np.array([
+                        [
+                            node_tuple_to_index[add_tuple(origin, offset)]
+                            for offset in el_offsets]
+                        for origin in gnitb(group.order, group.dim)])
 
                 vtk_cell_type = {
                         1: VTK_LINE,
@@ -291,11 +316,17 @@ class Visualizer(object):
 
 def make_visualizer(queue, discr, vis_order):
     from meshmode.discretization import Discretization
-    from meshmode.discretization.poly_element import \
-            PolynomialWarpAndBlendGroupFactory
+    from meshmode.discretization.poly_element import (
+            PolynomialWarpAndBlendElementGroup,
+            LegendreGaussLobattoTensorProductElementGroup,
+            OrderAndTypeBasedGroupFactory)
     vis_discr = Discretization(
             discr.cl_context, discr.mesh,
-            PolynomialWarpAndBlendGroupFactory(vis_order),
+            OrderAndTypeBasedGroupFactory(
+                vis_order,
+                simplex_group_class=PolynomialWarpAndBlendElementGroup,
+                tensor_product_group_class=(
+                    LegendreGaussLobattoTensorProductElementGroup)),
             real_dtype=discr.real_dtype)
     from meshmode.discretization.connection import \
             make_same_mesh_connection
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 20ee2b6d..38b153e1 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -454,32 +454,49 @@ def test_element_orientation():
 
 def test_merge_and_map(ctx_getter, visualize=False):
     from meshmode.mesh.io import generate_gmsh, FileSource
+    from meshmode.mesh.generation import generate_box_mesh
+    from meshmode.mesh import TensorProductElementGroup
+    from meshmode.discretization.poly_element import (
+            PolynomialWarpAndBlendGroupFactory,
+            LegendreGaussLobattoTensorProductGroupFactory)
 
     mesh_order = 3
 
-    mesh = generate_gmsh(
-            FileSource("blob-2d.step"), 2, order=mesh_order,
-            force_ambient_dim=2,
-            other_options=["-string", "Mesh.CharacteristicLengthMax = 0.02;"]
-            )
+    if 1:
+        mesh = generate_gmsh(
+                FileSource("blob-2d.step"), 2, order=mesh_order,
+                force_ambient_dim=2,
+                other_options=["-string", "Mesh.CharacteristicLengthMax = 0.02;"]
+                )
+
+        discr_grp_factory = PolynomialWarpAndBlendGroupFactory(3)
+    else:
+        mesh = generate_box_mesh(
+                (
+                    np.linspace(0, 1, 4),
+                    np.linspace(0, 1, 4),
+                    np.linspace(0, 1, 4),
+                    ),
+                10, group_factory=TensorProductElementGroup)
+
+        discr_grp_factory = LegendreGaussLobattoTensorProductGroupFactory(3)
 
     from meshmode.mesh.processing import merge_disjoint_meshes, affine_map
-    mesh2 = affine_map(mesh, A=np.eye(2), b=np.array([5, 0]))
+    mesh2 = affine_map(mesh,
+            A=np.eye(mesh.ambient_dim),
+            b=np.array([5, 0, 0])[:mesh.ambient_dim])
 
     mesh3 = merge_disjoint_meshes((mesh2, mesh))
 
     if visualize:
         from meshmode.discretization import Discretization
-        from meshmode.discretization.poly_element import \
-                PolynomialWarpAndBlendGroupFactory
         cl_ctx = ctx_getter()
         queue = cl.CommandQueue(cl_ctx)
 
-        discr = Discretization(cl_ctx, mesh3,
-                PolynomialWarpAndBlendGroupFactory(3))
+        discr = Discretization(cl_ctx, mesh3, discr_grp_factory)
 
         from meshmode.discretization.visualization import make_visualizer
-        vis = make_visualizer(queue, discr, 1)
+        vis = make_visualizer(queue, discr, 3)
         vis.write_vtk_file("merged.vtu", [])
 
 # }}}
-- 
GitLab