From 202a7da13c9d6b5bc214d5b3478a0e3060c34bb0 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Fri, 8 May 2020 21:12:54 -0500
Subject: [PATCH] make_curve_mesh: remove extra vertex for closed curves

Before, the function was constructing a mesh that had
an extra vertex when the curve was closed. This was inconcistent
with the group's `vertex_indices`.
---
 meshmode/mesh/generation.py | 28 +++++++++++++++++++++-------
 test/test_meshmode.py       | 28 ++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/meshmode/mesh/generation.py b/meshmode/mesh/generation.py
index 3540dadd..92c49b7d 100644
--- a/meshmode/mesh/generation.py
+++ b/meshmode/mesh/generation.py
@@ -237,6 +237,7 @@ starfish = NArmedStarfish(5, 0.25)
 def make_curve_mesh(curve_f, element_boundaries, order,
         unit_nodes=None,
         node_vertex_consistency_tolerance=None,
+        closed=True,
         return_parametrization_points=False):
     """
     :arg curve_f: A callable representing a parametrization for a curve,
@@ -245,10 +246,12 @@ def make_curve_mesh(curve_f, element_boundaries, order,
     :arg element_boundaries: a vector of element boundary locations in
         :math:`[0,1]`, in order. 0 must be the first entry, 1 the
         last one.
+    :arg closed: if *True*, the curve is assumed closed and the first and
+        last of the *element_boundaries* must match.
     :arg unit_nodes: if given, the unit nodes to use. Must have shape
-        ``(dim, nnoodes)``.
+        ``(dim, nnodes)``.
     :returns: a :class:`meshmode.mesh.Mesh`, or if *return_parametrization_points*
-        is True, a tuple ``(mesh, par_points)``, where *par_points* is an array of
+        is *True*, a tuple ``(mesh, par_points)``, where *par_points* is an array of
         parametrization points.
     """
 
@@ -260,7 +263,21 @@ def make_curve_mesh(curve_f, element_boundaries, order,
         unit_nodes = mp.warp_and_blend_nodes(1, order)
     nodes_01 = 0.5*(unit_nodes+1)
 
-    vertices = curve_f(element_boundaries)
+    wrap = nelements
+    if not closed:
+        wrap += 1
+
+    vertices = curve_f(element_boundaries)[:, :wrap]
+    vertex_indices = np.vstack([
+        np.arange(0, nelements, dtype=np.int32),
+        np.arange(1, nelements + 1, dtype=np.int32) % wrap
+        ]).T
+
+    assert vertices.shape[1] == np.max(vertex_indices) + 1
+    if closed:
+        assert la.norm(
+                curve_f(element_boundaries[0])
+                - curve_f(element_boundaries[-1])) < 1.0e-14
 
     el_lengths = np.diff(element_boundaries)
     el_starts = element_boundaries[:-1]
@@ -273,10 +290,7 @@ def make_curve_mesh(curve_f, element_boundaries, order,
     from meshmode.mesh import Mesh, SimplexElementGroup
     egroup = SimplexElementGroup(
             order,
-            vertex_indices=np.vstack([
-                np.arange(nelements, dtype=np.int32),
-                np.arange(1, nelements+1, dtype=np.int32) % nelements,
-                ]).T,
+            vertex_indices=vertex_indices,
             nodes=nodes,
             unit_nodes=unit_nodes)
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index c0be60c6..e2ee1a81 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -1177,6 +1177,34 @@ def test_mesh_without_vertices(ctx_factory):
     make_visualizer(queue, discr, 4)
 
 
+@pytest.mark.parametrize("curve_name", ["ellipse", "arc"])
+def test_open_curved_mesh(curve_name):
+    def arc_curve(t, start=0, end=np.pi):
+        return np.vstack([
+            np.cos((end - start) * t + start),
+            np.sin((end - start) * t + start)
+            ])
+
+    if curve_name == "ellipse":
+        from functools import partial
+        from meshmode.mesh.generation import ellipse
+        curve_f = partial(ellipse, 2.0)
+        closed = True
+    elif curve_name == "arc":
+        curve_f = arc_curve
+        closed = False
+    else:
+        raise ValueError("unknown curve")
+
+    from meshmode.mesh.generation import make_curve_mesh
+    nelements = 32
+    order = 4
+    make_curve_mesh(curve_f,
+            np.linspace(0.0, 1.0, nelements + 1),
+            order=order,
+            closed=closed)
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab