From e0d2682e7d555d7783c977fde609803750516f68 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Tue, 27 Sep 2016 00:53:07 -0500
Subject: [PATCH] High order refinement: first cut.

---
 .../discretization/connection/refinement.py   |  36 +----
 meshmode/mesh/refinement/__init__.py          | 132 ++++++++++++------
 meshmode/mesh/refinement/utils.py             |  34 +++++
 test/test_refinement.py                       |  45 ++++--
 4 files changed, 158 insertions(+), 89 deletions(-)

diff --git a/meshmode/discretization/connection/refinement.py b/meshmode/discretization/connection/refinement.py
index 5ef91fe7..ff009fce 100644
--- a/meshmode/discretization/connection/refinement.py
+++ b/meshmode/discretization/connection/refinement.py
@@ -31,39 +31,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-# {{{ Map unit nodes to children
-
-def _map_unit_nodes_to_children(unit_nodes, tesselation):
-    """
-    Given a collection of unit nodes, return the coordinates of the
-    unit nodes mapped onto each of the children of the reference
-    element.
-
-    The tesselation should follow the format of
-    :func:`meshmode.mesh.tesselate.tesselatetri()` or
-    :func:`meshmode.mesh.tesselate.tesselatetet()`.
-
-    `unit_nodes` should be relative to the unit simplex coordinates in
-    :module:`modepy`.
-
-    :arg unit_nodes: shaped `(dim, nunit_nodes)`
-    :arg tesselation: With attributes `ref_vertices`, `children`
-    """
-    ref_vertices = np.array(tesselation.ref_vertices, dtype=np.float)
-
-    assert len(unit_nodes.shape) == 2
-
-    for child_element in tesselation.children:
-        center = np.vstack(ref_vertices[child_element[0]])
-        # Scale by 1/2 since sides in the tesselation have length 2.
-        aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2
-        # (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices.
-        # Hence the translation by +/- 1.
-        yield aff_mat.dot(unit_nodes + 1) + center - 1
-
-# }}}
-
-
 # {{{ Build interpolation batches for group
 
 def _build_interpolation_batches_for_group(
@@ -123,7 +90,8 @@ def _build_interpolation_batches_for_group(
                 to_bin.append(child_idx)
 
     fine_unit_nodes = fine_discr_group.unit_nodes
-    mapped_unit_nodes = _map_unit_nodes_to_children(
+    from meshmode.mesh.refinement.utils import map_unit_nodes_to_children
+    mapped_unit_nodes = map_unit_nodes_to_children(
         fine_unit_nodes, record.tesselation)
 
     from itertools import chain
diff --git a/meshmode/mesh/refinement/__init__.py b/meshmode/mesh/refinement/__init__.py
index 16639442..8d7b5013 100644
--- a/meshmode/mesh/refinement/__init__.py
+++ b/meshmode/mesh/refinement/__init__.py
@@ -26,8 +26,6 @@ import itertools
 from six.moves import range
 from pytools import RecordWithoutPickling
 
-from meshmode.mesh.generation import make_group_from_vertices
-
 
 class TreeRayNode(object):
     """Describes a ray as a tree, this class represents each node in this tree
@@ -217,31 +215,6 @@ class Refiner(object):
         return self.previous_mesh
 
     def get_current_mesh(self):
-
-        from meshmode.mesh import Mesh
-        #return Mesh(vertices, [grp], nodal_adjacency=self.generate_nodal_adjacency(len(self.last_mesh.groups[group].vertex_indices) \
-        #            + count*3))
-        groups = []
-        grpn = 0
-        for grp in self.last_mesh.groups:
-            groups.append(np.empty([len(grp.vertex_indices),
-                len(self.last_mesh.groups[grpn].vertex_indices[0])], dtype=np.int32))
-            for iel_grp in range(grp.nelements):
-                for i in range(0, len(grp.vertex_indices[iel_grp])):
-                    groups[grpn][iel_grp][i] = grp.vertex_indices[iel_grp][i]
-            grpn += 1
-        grp = []
-
-        for grpn in range(0, len(groups)):
-            grp.append(make_group_from_vertices(self.last_mesh.vertices, groups[grpn], 4))
-        self.last_mesh = Mesh(
-                self.last_mesh.vertices, grp,
-                nodal_adjacency=self.generate_nodal_adjacency(
-                    len(self.last_mesh.groups[0].vertex_indices),
-                    len(self.last_mesh.vertices[0])),
-                vertex_id_dtype=self.last_mesh.vertex_id_dtype,
-                element_id_dtype=self.last_mesh.element_id_dtype)
-
         return self.last_mesh
 
     def get_leaves(self, cur_node):
@@ -590,15 +563,42 @@ class Refiner(object):
             element_mapping = []
             tesselation = None
 
+            # {{{ get midpoint coordinates for vertices
+
+            midpoints_to_find = []
+            resampler = None
+            for iel_grp in range(grp.nelements):
+                if refine_flags[iel_base + iel_grp]:
+                    # if simplex
+                    if len(grp.vertex_indices[iel_grp]) == grp.dim + 1:
+                        midpoints_to_find.append(iel_grp)
+                        if not resampler:
+                            from meshmode.mesh.refinement.resampler import (
+                                SimplexResampler)
+                            resampler = SimplexResampler()
+                            tesselation = self._Tesselation(
+                                self.simplex_result[grp.dim],
+                                self.simplex_node_tuples[grp.dim])
+                    else:
+                        raise NotImplementedError("unimplemented: midpoint finding"
+                                                  "for non simplex elements")
+
+            if midpoints_to_find:
+                midpoints = resampler.get_midpoints(
+                        grp, tesselation, midpoints_to_find)
+                midpoint_order = resampler.get_vertex_pair_to_midpoint_order(grp.dim)
+
+            del midpoints_to_find
+
+            # }}}
+
             for iel_grp in range(grp.nelements):
                 element_mapping.append([iel_grp])
                 if refine_flags[iel_base+iel_grp]:
                     midpoint_vertices = []
                     vertex_indices = grp.vertex_indices[iel_grp]
-                    #if simplex
+                    # if simplex
                     if len(vertex_indices) == grp.dim + 1:
-                        # {{{ Get midpoints for all pairs of vertices
-
                         for i in range(len(vertex_indices)):
                             for j in range(i+1, len(vertex_indices)):
                                 min_index = min(vertex_indices[i], vertex_indices[j])
@@ -617,18 +617,15 @@ class Refiner(object):
                                     vertex_pair2 = (max_index, vertices_index)
                                     self.pair_map[vertex_pair1] = cur_node.left
                                     self.pair_map[vertex_pair2] = cur_node.right
-                                    for k in range(len(self.last_mesh.vertices)):
-                                        vertices[k, vertices_index] = \
-                                        (self.last_mesh.vertices[k, vertex_indices[i]] +
-                                        self.last_mesh.vertices[k, vertex_indices[j]]) / 2.0
+                                    midpoint_idx = midpoint_order[(i, j)]
+                                    vertices[:, vertices_index] = \
+                                            midpoints[iel_grp][:,midpoint_idx]
                                     midpoint_vertices.append(vertices_index)
                                     vertices_index += 1
                                 else:
                                     cur_midpoint = cur_node.midpoint
                                     midpoint_vertices.append(cur_midpoint)
 
-                        # }}}
-
                         #generate new rays
                         cur_dim = len(grp.vertex_indices[0])-1
                         for i in range(len(midpoint_vertices)):
@@ -653,12 +650,12 @@ class Refiner(object):
                             for j in range(len(self.simplex_result[cur_dim][i])):
                                 groups[grpn][iel][j] = \
                                     node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
-                        tesselation = self._Tesselation(
-                            self.simplex_result[cur_dim], self.simplex_node_tuples[cur_dim])
                         nelements_in_grp += len(self.simplex_result[cur_dim])-1
                     #assuming quad otherwise
-                    #else:
+                    else:
                         #quadrilateral
+                        raise NotImplementedError("unimplemented: "
+                                                  "support for quad elements")
 #                        node_tuple_to_coord = {}
 #                        for node_index, node_tuple in enumerate(self.index_to_node_tuple[cur_dim]):
 #                            node_tuple_to_coord[node_tuple] = grp.vertex_indices[iel_grp][node_index]
@@ -699,15 +696,64 @@ class Refiner(object):
         #check_adjacent_elements(groups, new_hanging_vertex_element, nelements_in_grp)
 
         self.hanging_vertex_element = new_hanging_vertex_element
-        grp = []
-        for grpn in range(0, len(groups)):
-            grp.append(make_group_from_vertices(vertices, groups[grpn], 4))
+
+        # {{{ make new groups
+
+        new_mesh_el_groups = []
+
+        for refinement_record, group, prev_group in zip(
+                self.group_refinement_records, groups, self.last_mesh.groups):
+            is_simplex = len(prev_group.vertex_indices[0]) == prev_group.dim + 1
+            ambient_dim = len(prev_group.nodes)
+            nelements = len(group)
+            nunit_nodes = len(prev_group.unit_nodes[0])
+
+            nodes = np.empty(
+                (ambient_dim, nelements, nunit_nodes),
+                dtype=prev_group.nodes.dtype)
+
+            element_mapping = refinement_record.element_mapping
+
+            to_resample = [elem for elem in range(len(element_mapping))
+                           if len(element_mapping[elem]) > 1]
+
+            if to_resample:
+                # if simplex
+                if is_simplex:
+                    from meshmode.mesh.refinement.resampler import SimplexResampler
+                    resampler = SimplexResampler()
+                    new_nodes = resampler.get_tesselated_nodes(
+                        prev_group, refinement_record.tesselation, to_resample)
+                else:
+                    raise NotImplementedError(
+                        "unimplemented: node resampling for non simplex elements")
+
+            for elem, mapped_elems in enumerate(element_mapping):
+                if len(mapped_elems) == 1:
+                    # No resampling required, just copy over
+                    nodes[:,mapped_elems[0]] = prev_group.nodes[:, elem]
+                    n = nodes[:,mapped_elems[0]]
+                else:
+                    nodes[:, mapped_elems] = new_nodes[elem]
+
+            if is_simplex:
+                new_mesh_el_groups.append(
+                    type(prev_group)(
+                        order=prev_group.order,
+                        vertex_indices=group,
+                        nodes=nodes,
+                        unit_nodes=prev_group.unit_nodes))
+            else:
+                raise NotImplementedError("unimplemented: support for creating"
+                                          "non simplex element groups")
+
+        # }}}
 
         from meshmode.mesh import Mesh
 
         self.previous_mesh = self.last_mesh
         self.last_mesh = Mesh(
-                vertices, grp,
+                vertices, new_mesh_el_groups,
                 nodal_adjacency=self.generate_nodal_adjacency(
                     totalnelements, nvertices, groups),
                 vertex_id_dtype=self.last_mesh.vertex_id_dtype,
diff --git a/meshmode/mesh/refinement/utils.py b/meshmode/mesh/refinement/utils.py
index 751142aa..08b71f87 100644
--- a/meshmode/mesh/refinement/utils.py
+++ b/meshmode/mesh/refinement/utils.py
@@ -29,8 +29,42 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+# {{{ map unit nodes to children
+
+
+def map_unit_nodes_to_children(unit_nodes, tesselation):
+    """
+    Given a collection of unit nodes, return the coordinates of the
+    unit nodes mapped onto each of the children of the reference
+    element.
+
+    The tesselation should follow the format of
+    :func:`meshmode.mesh.tesselate.tesselatetri()` or
+    :func:`meshmode.mesh.tesselate.tesselatetet()`.
+
+    `unit_nodes` should be relative to the unit simplex coordinates in
+    :module:`modepy`.
+
+    :arg unit_nodes: shaped `(dim, nunit_nodes)`
+    :arg tesselation: With attributes `ref_vertices`, `children`
+    """
+    ref_vertices = np.array(tesselation.ref_vertices, dtype=np.float)
+
+    assert len(unit_nodes.shape) == 2
+
+    for child_element in tesselation.children:
+        center = np.vstack(ref_vertices[child_element[0]])
+        # Scale by 1/2 since sides in the tesselation have length 2.
+        aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2
+        # (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices.
+        # Hence the translation by +/- 1.
+        yield aff_mat.dot(unit_nodes + 1) + center - 1
+
+# }}}
+
 # {{{ test nodal adjacency against geometry
 
+
 def is_symmetric(relation, debug=False):
     for a, other_list in enumerate(relation):
         for b in other_list:
diff --git a/test/test_refinement.py b/test/test_refinement.py
index bed61eb9..cdb7ef4b 100644
--- a/test/test_refinement.py
+++ b/test/test_refinement.py
@@ -47,10 +47,10 @@ logger = logging.getLogger(__name__)
 from functools import partial
 
 
-def gen_blob_mesh(h=0.2):
+def gen_blob_mesh(h=0.2, order=1):
     from meshmode.mesh.io import generate_gmsh, FileSource
     return generate_gmsh(
-            FileSource("blob-2d.step"), 2, order=1,
+            FileSource("blob-2d.step"), 2, order=order,
             force_ambient_dim=2,
             other_options=[
                 "-string", "Mesh.CharacteristicLengthMax = %s;" % h]
@@ -147,21 +147,26 @@ def test_refinement(case_name, mesh_gen, flag_gen, num_generations):
     PolynomialEquidistantGroupFactory
     ])
 @pytest.mark.parametrize(("mesh_name", "dim", "mesh_pars"), [
-    ("circle", 1, [20, 30, 40]),
+    ("circle", 1, [10, 20, 30]),
     ("blob", 2, [1e-1, 8e-2, 5e-2]),
-    ("warp", 2, [4, 5, 6]),
+    ("warp", 2, [7, 8, 9]),
     ("warp", 3, [4, 5, 6]),
 ])
+@pytest.mark.parametrize("mesh_order", [1, 5])
 @pytest.mark.parametrize("refine_flags", [
     # FIXME: slow
-    # uniform_refine_flags,
+    #uniform_refine_flags,
     partial(random_refine_flags, 0.4)
 ])
 def test_refinement_connection(
-        ctx_getter, group_factory, mesh_name, dim, mesh_pars, refine_flags):
+        ctx_getter, group_factory, mesh_name, dim, mesh_pars, mesh_order,
+        refine_flags, plot_mesh=False):
     from random import seed
     seed(13)
 
+    # Discretization order
+    order = 5
+
     cl_ctx = ctx_getter()
     queue = cl.CommandQueue(cl_ctx)
 
@@ -172,8 +177,6 @@ def test_refinement_connection(
     from pytools.convergence import EOCRecorder
     eoc_rec = EOCRecorder()
 
-    order = 5
-
     def f(x):
         from six.moves import reduce
         return 0.1 * reduce(lambda x, y: x * cl.clmath.sin(5 * y), x)
@@ -185,14 +188,17 @@ def test_refinement_connection(
             assert dim == 1
             h = 1 / mesh_par
             mesh = make_curve_mesh(
-                partial(ellipse, 1), np.linspace(0, 1, mesh_par + 1), order=1)
+                partial(ellipse, 1), np.linspace(0, 1, mesh_par + 1),
+                order=mesh_order)
         elif mesh_name == "blob":
+            if mesh_order == 5:
+                pytest.xfail("")
             assert dim == 2
             h = mesh_par
-            mesh = gen_blob_mesh(h)
+            mesh = gen_blob_mesh(h, mesh_order)
         elif mesh_name == "warp":
             from meshmode.mesh.generation import generate_warped_rect_mesh
-            mesh = generate_warped_rect_mesh(dim, order=1, n=mesh_par)
+            mesh = generate_warped_rect_mesh(dim, order=mesh_order, n=mesh_par)
             h = 1/mesh_par
         else:
             raise ValueError("mesh_name not recognized")
@@ -202,7 +208,9 @@ def test_refinement_connection(
         discr = Discretization(cl_ctx, mesh, group_factory(order))
 
         refiner = Refiner(mesh)
-        refiner.refine(refine_flags(mesh))
+        flags = refine_flags(mesh)
+        refiner.refine(flags)
+
         connection = make_refinement_connection(
             refiner, discr, group_factory(order))
         check_connection(connection)
@@ -215,6 +223,19 @@ def test_refinement_connection(
         f_interp = connection(queue, f_coarse).with_queue(queue)
         f_true = f(x_fine).with_queue(queue)
 
+        if plot_mesh:
+            import matplotlib.pyplot as plt
+            x = x.get(queue)
+            err = np.array(np.log10(
+                1e-16 + np.abs((f_interp - f_true).get(queue))), dtype=float)
+            import matplotlib.cm as cm
+            cmap = cm.ScalarMappable(cmap=cm.jet)
+            cmap.set_array(err)
+            #norm = plt.matplotlib.colors.Normalize(vmin=min(err), vmax=max(err))
+            plt.scatter(x[0], x[1], c=cmap.to_rgba(err), s=20, cmap=cmap)
+            plt.colorbar(cmap)
+            plt.show()
+
         import numpy.linalg as la
         err = la.norm((f_interp - f_true).get(queue), np.inf)
         eoc_rec.add_data_point(h, err)
-- 
GitLab