From 36f1c95f481382453c2cc61ede1ed1bd4e6f8725 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 12 Aug 2016 14:25:27 -0500
Subject: [PATCH] make_curve_mesh: Allow specifying unit nodes and
 node_vertex_consistency_tolerance

---
 meshmode/mesh/generation.py | 31 +++++++++++++++++++++++--------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/meshmode/mesh/generation.py b/meshmode/mesh/generation.py
index 65761586..ce2a5f23 100644
--- a/meshmode/mesh/generation.py
+++ b/meshmode/mesh/generation.py
@@ -175,7 +175,10 @@ def qbx_peanut(t):
 
 # {{{ make_curve_mesh
 
-def make_curve_mesh(curve_f, element_boundaries, order):
+def make_curve_mesh(curve_f, element_boundaries, order,
+        unit_nodes=None,
+        node_vertex_consistency_tolerance=None,
+        return_parametrization_points=False):
     """
     :arg curve_f: A callable representing a parametrization for a curve,
         accepting a vector of point locations and returning
@@ -183,15 +186,20 @@ def make_curve_mesh(curve_f, element_boundaries, order):
     :arg element_boundaries: a vector of element boundary locations in
         :math:`[0,1]`, in order. 0 must be the first entry, 1 the
         last one.
-    :returns: a :class:`meshmode.mesh.Mesh`
+    :arg unit_nodes: if given, the unit nodes to use. Must have shape
+        ``(dim, nnoodes)``.
+    :returns: a :class:`meshmode.mesh.Mesh`, or if *return_parametrization_points*
+        is True, a tuple ``(mesh, par_points)``, where *par_points* is an array of
+        parametrization points.
     """
 
     assert element_boundaries[0] == 0
     assert element_boundaries[-1] == 1
     nelements = len(element_boundaries) - 1
 
-    unodes = mp.warp_and_blend_nodes(1, order)
-    nodes_01 = 0.5*(unodes+1)
+    if unit_nodes is None:
+        unit_nodes = mp.warp_and_blend_nodes(1, order)
+    nodes_01 = 0.5*(unit_nodes+1)
 
     vertices = curve_f(element_boundaries)
 
@@ -200,7 +208,8 @@ def make_curve_mesh(curve_f, element_boundaries, order):
 
     # (el_nr, node_nr)
     t = el_starts[:, np.newaxis] + el_lengths[:, np.newaxis]*nodes_01
-    nodes = curve_f(t.ravel()).reshape(vertices.shape[0], nelements, -1)
+    t = t.ravel()
+    nodes = curve_f(t).reshape(vertices.shape[0], nelements, -1)
 
     from meshmode.mesh import Mesh, SimplexElementGroup
     egroup = SimplexElementGroup(
@@ -210,12 +219,18 @@ def make_curve_mesh(curve_f, element_boundaries, order):
                 np.arange(1, nelements+1, dtype=np.int32) % nelements,
                 ]).T,
             nodes=nodes,
-            unit_nodes=unodes)
+            unit_nodes=unit_nodes)
 
-    return Mesh(
+    mesh = Mesh(
             vertices=vertices, groups=[egroup],
             nodal_adjacency=None,
-            facial_adjacency_groups=None)
+            facial_adjacency_groups=None,
+            node_vertex_consistency_tolerance=node_vertex_consistency_tolerance)
+
+    if return_parametrization_points:
+        return mesh, t
+    else:
+        return mesh
 
 # }}}
 
-- 
GitLab