From 0d126c155b5f99e89a699dd576061967137de92f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 17 Feb 2015 21:21:38 -0600
Subject: [PATCH] Add 2D connectivity plotting

---
 meshmode/mesh/__init__.py      |  5 ++++-
 meshmode/mesh/visualization.py | 41 +++++++++++++++++++++++++++++++---
 test/test_meshmode.py          |  2 +-
 3 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py
index 2347a87..6bd4b92 100644
--- a/meshmode/mesh/__init__.py
+++ b/meshmode/mesh/__init__.py
@@ -344,7 +344,6 @@ class Mesh(Record):
 
         return self._element_connectivity
 
-
     # Design experience: Try not to add too many global data structures to the
     # mesh. Let the element groups be responsible for that at the mesh level.
     #
@@ -423,6 +422,10 @@ def _compute_connectivity_from_vertices(mesh):
             for ivertex in grp.vertex_indices[iel_grp]:
                 element_to_element[iel_base + iel_grp].update(
                         vertex_to_element[ivertex])
+
+    for iel, neighbors in enumerate(element_to_element):
+        neighbors.remove(iel)
+
     lengths = [len(el_list) for el_list in element_to_element]
     neighbors_starts = np.cumsum(
             np.array([0] + lengths, dtype=mesh.element_id_dtype))
diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py
index 2b19da7..8966861 100644
--- a/meshmode/mesh/visualization.py
+++ b/meshmode/mesh/visualization.py
@@ -30,7 +30,7 @@ import numpy as np
 # {{{ draw_2d_mesh
 
 def draw_2d_mesh(mesh, draw_vertex_numbers=True, draw_element_numbers=True,
-        **kwargs):
+        draw_connectivity=False, **kwargs):
     assert mesh.ambient_dim == 2
 
     import matplotlib.pyplot as pt
@@ -57,8 +57,7 @@ def draw_2d_mesh(mesh, draw_vertex_numbers=True, draw_element_numbers=True,
             pt.gca().add_patch(patch)
 
             if draw_element_numbers:
-                centroid = (np.sum(elverts, axis=1)
-                        / elverts.shape[1])
+                centroid = (np.sum(elverts, axis=1) / elverts.shape[1])
 
                 if len(mesh.groups) == 1:
                     el_label = str(iel)
@@ -75,6 +74,42 @@ def draw_2d_mesh(mesh, draw_vertex_numbers=True, draw_element_numbers=True,
                     ha="center", va="center", color="blue",
                     bbox=dict(facecolor='white', alpha=0.5, lw=0))
 
+    if draw_connectivity:
+        def global_iel_to_group_and_iel(global_iel):
+            for igrp, grp in enumerate(mesh.groups):
+                if global_iel < grp.nelements:
+                    return grp, global_iel
+                global_iel -= grp.nelements
+
+            raise ValueError("invalid element nr")
+
+        cnx = mesh.element_connectivity
+
+        nb_starts = cnx.neighbors_starts
+        for iel_g in range(mesh.nelements):
+            for nb_iel_g in cnx.neighbors[nb_starts[iel_g]:nb_starts[iel_g+1]]:
+                assert iel_g != nb_iel_g
+
+                grp, iel = global_iel_to_group_and_iel(iel_g)
+                nb_grp, nb_iel = global_iel_to_group_and_iel(nb_iel_g)
+
+                elverts = mesh.vertices[:, grp.vertex_indices[iel]]
+                nb_elverts = mesh.vertices[:, nb_grp.vertex_indices[nb_iel]]
+
+                centroid = (np.sum(elverts, axis=1) / elverts.shape[1])
+                nb_centroid = (np.sum(nb_elverts, axis=1) / nb_elverts.shape[1])
+
+                dx = nb_centroid - centroid
+                start = centroid + 0.15*dx
+
+                mag = np.max(np.abs(dx))
+                start += 0.05*(np.random.rand(2)-0.5)*mag
+                dx += 0.05*(np.random.rand(2)-0.5)*mag
+
+                pt.arrow(start[0], start[1], 0.7*dx[0], 0.7*dx[1],
+                        length_includes_head=True,
+                        color="green", head_width=1e-2, lw=1e-2)
+
 # }}}
 
 # vim: foldmethod=marker
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 3452105..6047113 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -351,7 +351,7 @@ def test_rect_mesh(do_plot=False):
 
     if do_plot:
         from meshmode.mesh.visualization import draw_2d_mesh
-        draw_2d_mesh(mesh, fill=None)
+        draw_2d_mesh(mesh, fill=None, draw_connectivity=True)
         import matplotlib.pyplot as pt
         pt.show()
 
-- 
GitLab