From 2b7e1e49d5b369a0ca0261921a6485c65826b3ba Mon Sep 17 00:00:00 2001
From: Shivam Gupta <sgupta72@illinois.edu>
Date: Tue, 28 Apr 2015 17:54:56 -0500
Subject: [PATCH] Fixed tests for mesh in 3d space

---
 test/test_meshmode.py | 78 ++++++++++++++++++++++++++++++++-----------
 1 file changed, 59 insertions(+), 19 deletions(-)

diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index 63a00c30..cacfa3b8 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -407,13 +407,52 @@ def get_vertex(mesh, vertex_index):
     return vertex
 
 
-def get_some_mesh():
-    from meshmode.mesh.generation import generate_regular_rect_mesh
-    return generate_regular_rect_mesh()
+def get_blobby_2d_mesh():
+    from meshmode.mesh.io import generate_gmsh, FileSource
+    return generate_gmsh(
+            FileSource("blob-2d.step"), 2, order=2,
+            force_ambient_dim=2,
+            other_options=[
+                "-string", "Mesh.CharacteristicLengthMax = 1e-1;"]
+            )
+
 
+def get_some_mesh():
+    from meshmode.mesh.generation import (  # noqa
+            generate_icosphere, generate_icosahedron,
+            generate_torus, generate_regular_rect_mesh)
+    import random
+    #mesh = generate_icosphere(1, order=4)
+    #mesh = generate_icosahedron(1, order=4)
+    mesh = generate_torus(3, 1, order=4)
+
+    #mesh = generate_regular_rect_mesh()
+
+    from meshmode.mesh.refinement import Refiner
+    r = Refiner(mesh)
+    times = random.randint(1, 2)
+    for time in xrange(times):
+        flags = np.zeros(len(mesh.groups[0].vertex_indices))
+        '''
+        if time == 0:
+            flags[0] = 1
+        if time == 1:
+            #flags[1] = 1
+            flags[33] = 1
+        if time == 2:
+            flags[3] = 1
+            pass
+            #flags[8] = 1
+            #flags[40] = 1
+        '''
+        for i in xrange(0, len(flags)):
+            flags[i] = random.randint(0, 1)
+        mesh = r.refine(flags)
+    return mesh
 
 @pytest.mark.parametrize("mesh_factory", [get_some_mesh])
-def test_refiner_connectivity(mesh_factory):
+@pytest.mark.parametrize("num_rounds", [1,2,3])
+def test_refiner_connectivity(mesh_factory, num_rounds):
     mesh = mesh_factory()
     def group_and_iel_to_global_iel(igrp, iel):
         return mesh.groups[igrp].element_nr_base + iel
@@ -459,25 +498,25 @@ def test_refiner_connectivity(mesh_factory):
                         enumerate(bounding_grp.vertex_indices[bounding_iel][:nvertices_per_element-1]):
                         bounding_vertex = get_vertex(mesh, bounding_vertex_index)
                         transformation[:,ibounding_vertex_index] = bounding_vertex - last_bounding_vertex
-                    barycentric_coordinates = np.linalg.solve(transformation, vertex_transformed)
-                    is_connected = True
-                    sum_of_coords = 0.0
-                    for coord in barycentric_coordinates:
-                        if coord < 0:
-                            is_connected = False
-                        sum_of_coords += coord
-                    if sum_of_coords > 1:
-                        is_connected = False
-                    if is_connected:
+                    barycentric_coordinates, resid, _, _ = np.linalg.lstsq(transformation, vertex_transformed)
+
+                    bc = barycentric_coordinates
+                    tol = 1e-12
+
+                    if resid > tol:
+                        continue
+
+                    if ((bc > -tol).all() and np.sum(bc) < 1+tol):
                         connected_to_element_geometry[group_and_iel_to_global_iel(bounding_igrp, bounding_iel),\
                                 group_and_iel_to_global_iel(igrp, iel_grp)] = True
                         connected_to_element_geometry[group_and_iel_to_global_iel(igrp, iel_grp),\
                                 group_and_iel_to_global_iel(bounding_igrp, bounding_iel)] = True
+
     '''
-    print "GEOMETRY: "
-    print connected_to_element_geometry 
-    print "CONNECTIVITY: "
-    print connected_to_element_connectivity
+    print ("GEOMETRY: ")
+    print (connected_to_element_geometry)
+    print ("CONNECTIVITY: ")
+    print (connected_to_element_connectivity)
 #    cmpmatrix = (connected_to_element_geometry == connected_to_element_connectivity)
     #print cmpmatrix
 #    print np.where(cmpmatrix == False)
@@ -486,7 +525,8 @@ def test_refiner_connectivity(mesh_factory):
         for ii, i in enumerate(cmpmatrix):
             for ij, j in enumerate(i):
                 if j == False:
-                    print ii, ij
+                    pass
+                    print (ii, ij)
     '''
     assert (connected_to_element_geometry == connected_to_element_connectivity).all()
 
-- 
GitLab