From 69217f169c4e9c6f4631997b91b88cb477af07a9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 22 Feb 2017 12:55:36 -0600
Subject: [PATCH] Visualizer: add proper support for multiple inhomogeneous
 groups

---
 meshmode/discretization/visualization.py | 141 ++++++++++++++++-------
 1 file changed, 100 insertions(+), 41 deletions(-)

diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py
index d68ec0c..bfd8bc5 100644
--- a/meshmode/discretization/visualization.py
+++ b/meshmode/discretization/visualization.py
@@ -1,6 +1,4 @@
-from __future__ import division
-from __future__ import absolute_import
-from six.moves import range
+from __future__ import division, absolute_import
 
 __copyright__ = "Copyright (C) 2014 Andreas Kloeckner"
 
@@ -24,8 +22,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from six.moves import range
 import numpy as np
-from pytools import memoize_method
+from pytools import memoize_method, Record
 import pyopencl as cl
 
 __doc__ = """
@@ -71,6 +70,34 @@ def separate_by_real_and_imag(data, real_only):
                 yield (name, field)
 
 
+class _VisConnectivityGroup(Record):
+    """
+    .. attribute:: vis_connectivity
+
+        an array of shape ``(group.nelements,nsubelements,primitive_element_size)``
+
+    .. attribute:: vtk_cell_type
+
+    .. attribute:: subelement_nr_base
+    """
+
+    @property
+    def nsubelements(self):
+        return self.nelements * self.nsubelements_per_element
+
+    @property
+    def nelements(self):
+        return self.vis_connectivity.shape[0]
+
+    @property
+    def nsubelements_per_element(self):
+        return self.vis_connectivity.shape[1]
+
+    @property
+    def primitive_element_size(self):
+        return self.vis_connectivity.shape[2]
+
+
 class Visualizer(object):
     """
     .. automethod:: show_scalar_in_mayavi
@@ -93,39 +120,67 @@ class Visualizer(object):
     @memoize_method
     def _vis_connectivity(self):
         """
-        :return: an array of shape
-            ``(vis_discr.nelements,nsubelements,primitive_element_size)``
+        :return: a list of :class:`_VisConnectivityGroup` instances.
         """
         # Assume that we're using modepy's default node ordering.
 
         from pytools import generate_nonnegative_integer_tuples_summing_to_at_most \
                 as gnitstam, single_valued
-        vis_order = single_valued(
-                group.order for group in self.vis_discr.groups)
-        node_tuples = list(gnitstam(vis_order, self.vis_discr.dim))
+        from meshmode.mesh import TensorProductElementGroup, SimplexElementGroup
 
-        from modepy.tools import submesh
-        el_connectivity = np.array(
-                submesh(node_tuples),
-                dtype=np.intp)
+        result = []
 
-        nelements = sum(group.nelements for group in self.vis_discr.groups)
-        vis_connectivity = np.empty(
-                (nelements,) + el_connectivity.shape, dtype=np.intp)
+        from pyvisfile.vtk import (
+                VTK_LINE, VTK_TRIANGLE, VTK_TETRA,
+                VTK_QUAD, VTK_HEXAHEDRON)
+
+        subel_nr_base = 0
 
-        el_nr_base = 0
         for group in self.vis_discr.groups:
+            if isinstance(group.mesh_el_group, SimplexElementGroup):
+                vis_order = single_valued(
+                        group.order for group in self.vis_discr.groups)
+                node_tuples = list(gnitstam(vis_order, self.vis_discr.dim))
+
+                from modepy.tools import submesh
+                el_connectivity = np.array(
+                        submesh(node_tuples),
+                        dtype=np.intp)
+                vtk_cell_type = {
+                        1: VTK_LINE,
+                        2: VTK_TRIANGLE,
+                        3: VTK_TETRA,
+                        }[group.dim]
+
+            elif isinstance(group.mesh_el_group, TensorProductElementGroup):
+                raise NotImplementedError()
+
+                vtk_cell_type = {
+                        1: VTK_LINE,
+                        2: VTK_QUAD,
+                        3: VTK_HEXAHEDRON,
+                        }[group.dim]
+
+            else:
+                raise NotImplementedError("visualization for element groups "
+                        "of type '%s'" % type(group.mesh_el_group).__name__)
+
             assert len(node_tuples) == group.nunit_nodes
-            vis_connectivity[el_nr_base:el_nr_base+group.nelements] = (
-                    np.arange(
-                        el_nr_base*group.nunit_nodes,
-                        (el_nr_base+group.nelements)*group.nunit_nodes,
-                        group.nunit_nodes
+            vis_connectivity = (
+                    group.node_nr_base + np.arange(
+                        0, group.nelements*group.nunit_nodes, group.nunit_nodes
                         )[:, np.newaxis, np.newaxis]
-                    + el_connectivity)
-            el_nr_base += group.nelements
+                    + el_connectivity).astype(np.intp)
+
+            vgrp = _VisConnectivityGroup(
+                vis_connectivity=vis_connectivity,
+                vtk_cell_type=vtk_cell_type,
+                subelement_nr_base=subel_nr_base)
+            result.append(vgrp)
 
-        return vis_connectivity
+            subel_nr_base += vgrp.nsubelements
+
+        return result
 
     def show_scalar_in_mayavi(self, field, **kwargs):
         import mayavi.mlab as mlab
@@ -140,7 +195,7 @@ class Visualizer(object):
         assert nodes.shape[0] == self.vis_discr.ambient_dim
         #mlab.points3d(nodes[0], nodes[1], 0*nodes[0])
 
-        vis_connectivity = self._vis_connectivity()
+        vis_connectivity, = self._vis_connectivity()
 
         if self.vis_discr.dim == 1:
             nodes = list(nodes)
@@ -153,6 +208,7 @@ class Visualizer(object):
 
             # http://docs.enthought.com/mayavi/mayavi/auto/example_plotting_many_lines.html  # noqa
             src = mlab.pipeline.scalar_scatter(*args)
+
             src.mlab_source.dataset.lines = vis_connectivity.reshape(-1, 2)
             lines = mlab.pipeline.stripper(src)
             mlab.pipeline.surface(lines, **kwargs)
@@ -181,15 +237,7 @@ class Visualizer(object):
         from pyvisfile.vtk import (
                 UnstructuredGrid, DataArray,
                 AppendedDataXMLGenerator,
-                VTK_LINE, VTK_TRIANGLE, VTK_TETRA,
                 VF_LIST_OF_COMPONENTS)
-        el_types = {
-                1: VTK_LINE,
-                2: VTK_TRIANGLE,
-                3: VTK_TETRA,
-                }
-
-        el_type = el_types[self.vis_discr.dim]
 
         with cl.CommandQueue(self.vis_discr.cl_context) as queue:
             nodes = self.vis_discr.nodes().with_queue(queue).get()
@@ -198,20 +246,31 @@ class Visualizer(object):
                     (name, self._resample_and_get(queue, fld))
                     for name, fld in names_and_fields]
 
-        connectivity = self._vis_connectivity()
+        vc_groups = self._vis_connectivity()
+
+        # {{{ create cell_types
+
+        nsubelements = sum(vgrp.nsubelements for vgrp in vc_groups)
+        cell_types = np.empty(nsubelements, dtype=np.uint8)
+        cell_types.fill(255)
+        for vgrp in vc_groups:
+            cell_types[
+                    vgrp.subelement_nr_base:
+                    vgrp.subelement_nr_base + vgrp.nsubelements] = \
+                            vgrp.vtk_cell_type
+        assert (cell_types < 255).all()
 
-        nprimitive_elements = (
-                connectivity.shape[0]
-                * connectivity.shape[1])
+        # }}}
 
         grid = UnstructuredGrid(
                 (self.vis_discr.nnodes,
                     DataArray("points",
                         nodes.reshape(self.vis_discr.ambient_dim, -1),
                         vector_format=VF_LIST_OF_COMPONENTS)),
-                cells=connectivity.reshape(-1),
-                cell_types=np.asarray([el_type] * nprimitive_elements,
-                    dtype=np.uint8))
+                cells=np.hstack([
+                    vgrp.vis_connectivity.reshape(-1)
+                    for vgrp in vc_groups]),
+                cell_types=cell_types)
 
         # for name, field in separate_by_real_and_imag(cell_data, real_only):
         #     grid.add_celldata(DataArray(name, field,
-- 
GitLab