Skip to content
Snippets Groups Projects
Commit e0d2682e authored by Matt Wala's avatar Matt Wala
Browse files

High order refinement: first cut.

parent efc7338c
Branches
Tags
1 merge request!5High order refiner
Pipeline #
...@@ -31,39 +31,6 @@ import logging ...@@ -31,39 +31,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# {{{ Map unit nodes to children
def _map_unit_nodes_to_children(unit_nodes, tesselation):
"""
Given a collection of unit nodes, return the coordinates of the
unit nodes mapped onto each of the children of the reference
element.
The tesselation should follow the format of
:func:`meshmode.mesh.tesselate.tesselatetri()` or
:func:`meshmode.mesh.tesselate.tesselatetet()`.
`unit_nodes` should be relative to the unit simplex coordinates in
:module:`modepy`.
:arg unit_nodes: shaped `(dim, nunit_nodes)`
:arg tesselation: With attributes `ref_vertices`, `children`
"""
ref_vertices = np.array(tesselation.ref_vertices, dtype=np.float)
assert len(unit_nodes.shape) == 2
for child_element in tesselation.children:
center = np.vstack(ref_vertices[child_element[0]])
# Scale by 1/2 since sides in the tesselation have length 2.
aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2
# (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices.
# Hence the translation by +/- 1.
yield aff_mat.dot(unit_nodes + 1) + center - 1
# }}}
# {{{ Build interpolation batches for group # {{{ Build interpolation batches for group
def _build_interpolation_batches_for_group( def _build_interpolation_batches_for_group(
...@@ -123,7 +90,8 @@ def _build_interpolation_batches_for_group( ...@@ -123,7 +90,8 @@ def _build_interpolation_batches_for_group(
to_bin.append(child_idx) to_bin.append(child_idx)
fine_unit_nodes = fine_discr_group.unit_nodes fine_unit_nodes = fine_discr_group.unit_nodes
mapped_unit_nodes = _map_unit_nodes_to_children( from meshmode.mesh.refinement.utils import map_unit_nodes_to_children
mapped_unit_nodes = map_unit_nodes_to_children(
fine_unit_nodes, record.tesselation) fine_unit_nodes, record.tesselation)
from itertools import chain from itertools import chain
......
...@@ -26,8 +26,6 @@ import itertools ...@@ -26,8 +26,6 @@ import itertools
from six.moves import range from six.moves import range
from pytools import RecordWithoutPickling from pytools import RecordWithoutPickling
from meshmode.mesh.generation import make_group_from_vertices
class TreeRayNode(object): class TreeRayNode(object):
"""Describes a ray as a tree, this class represents each node in this tree """Describes a ray as a tree, this class represents each node in this tree
...@@ -217,31 +215,6 @@ class Refiner(object): ...@@ -217,31 +215,6 @@ class Refiner(object):
return self.previous_mesh return self.previous_mesh
def get_current_mesh(self): def get_current_mesh(self):
from meshmode.mesh import Mesh
#return Mesh(vertices, [grp], nodal_adjacency=self.generate_nodal_adjacency(len(self.last_mesh.groups[group].vertex_indices) \
# + count*3))
groups = []
grpn = 0
for grp in self.last_mesh.groups:
groups.append(np.empty([len(grp.vertex_indices),
len(self.last_mesh.groups[grpn].vertex_indices[0])], dtype=np.int32))
for iel_grp in range(grp.nelements):
for i in range(0, len(grp.vertex_indices[iel_grp])):
groups[grpn][iel_grp][i] = grp.vertex_indices[iel_grp][i]
grpn += 1
grp = []
for grpn in range(0, len(groups)):
grp.append(make_group_from_vertices(self.last_mesh.vertices, groups[grpn], 4))
self.last_mesh = Mesh(
self.last_mesh.vertices, grp,
nodal_adjacency=self.generate_nodal_adjacency(
len(self.last_mesh.groups[0].vertex_indices),
len(self.last_mesh.vertices[0])),
vertex_id_dtype=self.last_mesh.vertex_id_dtype,
element_id_dtype=self.last_mesh.element_id_dtype)
return self.last_mesh return self.last_mesh
def get_leaves(self, cur_node): def get_leaves(self, cur_node):
...@@ -590,15 +563,42 @@ class Refiner(object): ...@@ -590,15 +563,42 @@ class Refiner(object):
element_mapping = [] element_mapping = []
tesselation = None tesselation = None
# {{{ get midpoint coordinates for vertices
midpoints_to_find = []
resampler = None
for iel_grp in range(grp.nelements):
if refine_flags[iel_base + iel_grp]:
# if simplex
if len(grp.vertex_indices[iel_grp]) == grp.dim + 1:
midpoints_to_find.append(iel_grp)
if not resampler:
from meshmode.mesh.refinement.resampler import (
SimplexResampler)
resampler = SimplexResampler()
tesselation = self._Tesselation(
self.simplex_result[grp.dim],
self.simplex_node_tuples[grp.dim])
else:
raise NotImplementedError("unimplemented: midpoint finding"
"for non simplex elements")
if midpoints_to_find:
midpoints = resampler.get_midpoints(
grp, tesselation, midpoints_to_find)
midpoint_order = resampler.get_vertex_pair_to_midpoint_order(grp.dim)
del midpoints_to_find
# }}}
for iel_grp in range(grp.nelements): for iel_grp in range(grp.nelements):
element_mapping.append([iel_grp]) element_mapping.append([iel_grp])
if refine_flags[iel_base+iel_grp]: if refine_flags[iel_base+iel_grp]:
midpoint_vertices = [] midpoint_vertices = []
vertex_indices = grp.vertex_indices[iel_grp] vertex_indices = grp.vertex_indices[iel_grp]
#if simplex # if simplex
if len(vertex_indices) == grp.dim + 1: if len(vertex_indices) == grp.dim + 1:
# {{{ Get midpoints for all pairs of vertices
for i in range(len(vertex_indices)): for i in range(len(vertex_indices)):
for j in range(i+1, len(vertex_indices)): for j in range(i+1, len(vertex_indices)):
min_index = min(vertex_indices[i], vertex_indices[j]) min_index = min(vertex_indices[i], vertex_indices[j])
...@@ -617,18 +617,15 @@ class Refiner(object): ...@@ -617,18 +617,15 @@ class Refiner(object):
vertex_pair2 = (max_index, vertices_index) vertex_pair2 = (max_index, vertices_index)
self.pair_map[vertex_pair1] = cur_node.left self.pair_map[vertex_pair1] = cur_node.left
self.pair_map[vertex_pair2] = cur_node.right self.pair_map[vertex_pair2] = cur_node.right
for k in range(len(self.last_mesh.vertices)): midpoint_idx = midpoint_order[(i, j)]
vertices[k, vertices_index] = \ vertices[:, vertices_index] = \
(self.last_mesh.vertices[k, vertex_indices[i]] + midpoints[iel_grp][:,midpoint_idx]
self.last_mesh.vertices[k, vertex_indices[j]]) / 2.0
midpoint_vertices.append(vertices_index) midpoint_vertices.append(vertices_index)
vertices_index += 1 vertices_index += 1
else: else:
cur_midpoint = cur_node.midpoint cur_midpoint = cur_node.midpoint
midpoint_vertices.append(cur_midpoint) midpoint_vertices.append(cur_midpoint)
# }}}
#generate new rays #generate new rays
cur_dim = len(grp.vertex_indices[0])-1 cur_dim = len(grp.vertex_indices[0])-1
for i in range(len(midpoint_vertices)): for i in range(len(midpoint_vertices)):
...@@ -653,12 +650,12 @@ class Refiner(object): ...@@ -653,12 +650,12 @@ class Refiner(object):
for j in range(len(self.simplex_result[cur_dim][i])): for j in range(len(self.simplex_result[cur_dim][i])):
groups[grpn][iel][j] = \ groups[grpn][iel][j] = \
node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]] node_tuple_to_coord[self.simplex_node_tuples[cur_dim][self.simplex_result[cur_dim][i][j]]]
tesselation = self._Tesselation(
self.simplex_result[cur_dim], self.simplex_node_tuples[cur_dim])
nelements_in_grp += len(self.simplex_result[cur_dim])-1 nelements_in_grp += len(self.simplex_result[cur_dim])-1
#assuming quad otherwise #assuming quad otherwise
#else: else:
#quadrilateral #quadrilateral
raise NotImplementedError("unimplemented: "
"support for quad elements")
# node_tuple_to_coord = {} # node_tuple_to_coord = {}
# for node_index, node_tuple in enumerate(self.index_to_node_tuple[cur_dim]): # for node_index, node_tuple in enumerate(self.index_to_node_tuple[cur_dim]):
# node_tuple_to_coord[node_tuple] = grp.vertex_indices[iel_grp][node_index] # node_tuple_to_coord[node_tuple] = grp.vertex_indices[iel_grp][node_index]
...@@ -699,15 +696,64 @@ class Refiner(object): ...@@ -699,15 +696,64 @@ class Refiner(object):
#check_adjacent_elements(groups, new_hanging_vertex_element, nelements_in_grp) #check_adjacent_elements(groups, new_hanging_vertex_element, nelements_in_grp)
self.hanging_vertex_element = new_hanging_vertex_element self.hanging_vertex_element = new_hanging_vertex_element
grp = []
for grpn in range(0, len(groups)): # {{{ make new groups
grp.append(make_group_from_vertices(vertices, groups[grpn], 4))
new_mesh_el_groups = []
for refinement_record, group, prev_group in zip(
self.group_refinement_records, groups, self.last_mesh.groups):
is_simplex = len(prev_group.vertex_indices[0]) == prev_group.dim + 1
ambient_dim = len(prev_group.nodes)
nelements = len(group)
nunit_nodes = len(prev_group.unit_nodes[0])
nodes = np.empty(
(ambient_dim, nelements, nunit_nodes),
dtype=prev_group.nodes.dtype)
element_mapping = refinement_record.element_mapping
to_resample = [elem for elem in range(len(element_mapping))
if len(element_mapping[elem]) > 1]
if to_resample:
# if simplex
if is_simplex:
from meshmode.mesh.refinement.resampler import SimplexResampler
resampler = SimplexResampler()
new_nodes = resampler.get_tesselated_nodes(
prev_group, refinement_record.tesselation, to_resample)
else:
raise NotImplementedError(
"unimplemented: node resampling for non simplex elements")
for elem, mapped_elems in enumerate(element_mapping):
if len(mapped_elems) == 1:
# No resampling required, just copy over
nodes[:,mapped_elems[0]] = prev_group.nodes[:, elem]
n = nodes[:,mapped_elems[0]]
else:
nodes[:, mapped_elems] = new_nodes[elem]
if is_simplex:
new_mesh_el_groups.append(
type(prev_group)(
order=prev_group.order,
vertex_indices=group,
nodes=nodes,
unit_nodes=prev_group.unit_nodes))
else:
raise NotImplementedError("unimplemented: support for creating"
"non simplex element groups")
# }}}
from meshmode.mesh import Mesh from meshmode.mesh import Mesh
self.previous_mesh = self.last_mesh self.previous_mesh = self.last_mesh
self.last_mesh = Mesh( self.last_mesh = Mesh(
vertices, grp, vertices, new_mesh_el_groups,
nodal_adjacency=self.generate_nodal_adjacency( nodal_adjacency=self.generate_nodal_adjacency(
totalnelements, nvertices, groups), totalnelements, nvertices, groups),
vertex_id_dtype=self.last_mesh.vertex_id_dtype, vertex_id_dtype=self.last_mesh.vertex_id_dtype,
......
...@@ -29,8 +29,42 @@ import logging ...@@ -29,8 +29,42 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# {{{ map unit nodes to children
def map_unit_nodes_to_children(unit_nodes, tesselation):
"""
Given a collection of unit nodes, return the coordinates of the
unit nodes mapped onto each of the children of the reference
element.
The tesselation should follow the format of
:func:`meshmode.mesh.tesselate.tesselatetri()` or
:func:`meshmode.mesh.tesselate.tesselatetet()`.
`unit_nodes` should be relative to the unit simplex coordinates in
:module:`modepy`.
:arg unit_nodes: shaped `(dim, nunit_nodes)`
:arg tesselation: With attributes `ref_vertices`, `children`
"""
ref_vertices = np.array(tesselation.ref_vertices, dtype=np.float)
assert len(unit_nodes.shape) == 2
for child_element in tesselation.children:
center = np.vstack(ref_vertices[child_element[0]])
# Scale by 1/2 since sides in the tesselation have length 2.
aff_mat = (ref_vertices.T[:, child_element[1:]] - center) / 2
# (-1, -1, ...) in unit_nodes = (0, 0, ...) in ref_vertices.
# Hence the translation by +/- 1.
yield aff_mat.dot(unit_nodes + 1) + center - 1
# }}}
# {{{ test nodal adjacency against geometry # {{{ test nodal adjacency against geometry
def is_symmetric(relation, debug=False): def is_symmetric(relation, debug=False):
for a, other_list in enumerate(relation): for a, other_list in enumerate(relation):
for b in other_list: for b in other_list:
......
...@@ -47,10 +47,10 @@ logger = logging.getLogger(__name__) ...@@ -47,10 +47,10 @@ logger = logging.getLogger(__name__)
from functools import partial from functools import partial
def gen_blob_mesh(h=0.2): def gen_blob_mesh(h=0.2, order=1):
from meshmode.mesh.io import generate_gmsh, FileSource from meshmode.mesh.io import generate_gmsh, FileSource
return generate_gmsh( return generate_gmsh(
FileSource("blob-2d.step"), 2, order=1, FileSource("blob-2d.step"), 2, order=order,
force_ambient_dim=2, force_ambient_dim=2,
other_options=[ other_options=[
"-string", "Mesh.CharacteristicLengthMax = %s;" % h] "-string", "Mesh.CharacteristicLengthMax = %s;" % h]
...@@ -147,21 +147,26 @@ def test_refinement(case_name, mesh_gen, flag_gen, num_generations): ...@@ -147,21 +147,26 @@ def test_refinement(case_name, mesh_gen, flag_gen, num_generations):
PolynomialEquidistantGroupFactory PolynomialEquidistantGroupFactory
]) ])
@pytest.mark.parametrize(("mesh_name", "dim", "mesh_pars"), [ @pytest.mark.parametrize(("mesh_name", "dim", "mesh_pars"), [
("circle", 1, [20, 30, 40]), ("circle", 1, [10, 20, 30]),
("blob", 2, [1e-1, 8e-2, 5e-2]), ("blob", 2, [1e-1, 8e-2, 5e-2]),
("warp", 2, [4, 5, 6]), ("warp", 2, [7, 8, 9]),
("warp", 3, [4, 5, 6]), ("warp", 3, [4, 5, 6]),
]) ])
@pytest.mark.parametrize("mesh_order", [1, 5])
@pytest.mark.parametrize("refine_flags", [ @pytest.mark.parametrize("refine_flags", [
# FIXME: slow # FIXME: slow
# uniform_refine_flags, #uniform_refine_flags,
partial(random_refine_flags, 0.4) partial(random_refine_flags, 0.4)
]) ])
def test_refinement_connection( def test_refinement_connection(
ctx_getter, group_factory, mesh_name, dim, mesh_pars, refine_flags): ctx_getter, group_factory, mesh_name, dim, mesh_pars, mesh_order,
refine_flags, plot_mesh=False):
from random import seed from random import seed
seed(13) seed(13)
# Discretization order
order = 5
cl_ctx = ctx_getter() cl_ctx = ctx_getter()
queue = cl.CommandQueue(cl_ctx) queue = cl.CommandQueue(cl_ctx)
...@@ -172,8 +177,6 @@ def test_refinement_connection( ...@@ -172,8 +177,6 @@ def test_refinement_connection(
from pytools.convergence import EOCRecorder from pytools.convergence import EOCRecorder
eoc_rec = EOCRecorder() eoc_rec = EOCRecorder()
order = 5
def f(x): def f(x):
from six.moves import reduce from six.moves import reduce
return 0.1 * reduce(lambda x, y: x * cl.clmath.sin(5 * y), x) return 0.1 * reduce(lambda x, y: x * cl.clmath.sin(5 * y), x)
...@@ -185,14 +188,17 @@ def test_refinement_connection( ...@@ -185,14 +188,17 @@ def test_refinement_connection(
assert dim == 1 assert dim == 1
h = 1 / mesh_par h = 1 / mesh_par
mesh = make_curve_mesh( mesh = make_curve_mesh(
partial(ellipse, 1), np.linspace(0, 1, mesh_par + 1), order=1) partial(ellipse, 1), np.linspace(0, 1, mesh_par + 1),
order=mesh_order)
elif mesh_name == "blob": elif mesh_name == "blob":
if mesh_order == 5:
pytest.xfail("")
assert dim == 2 assert dim == 2
h = mesh_par h = mesh_par
mesh = gen_blob_mesh(h) mesh = gen_blob_mesh(h, mesh_order)
elif mesh_name == "warp": elif mesh_name == "warp":
from meshmode.mesh.generation import generate_warped_rect_mesh from meshmode.mesh.generation import generate_warped_rect_mesh
mesh = generate_warped_rect_mesh(dim, order=1, n=mesh_par) mesh = generate_warped_rect_mesh(dim, order=mesh_order, n=mesh_par)
h = 1/mesh_par h = 1/mesh_par
else: else:
raise ValueError("mesh_name not recognized") raise ValueError("mesh_name not recognized")
...@@ -202,7 +208,9 @@ def test_refinement_connection( ...@@ -202,7 +208,9 @@ def test_refinement_connection(
discr = Discretization(cl_ctx, mesh, group_factory(order)) discr = Discretization(cl_ctx, mesh, group_factory(order))
refiner = Refiner(mesh) refiner = Refiner(mesh)
refiner.refine(refine_flags(mesh)) flags = refine_flags(mesh)
refiner.refine(flags)
connection = make_refinement_connection( connection = make_refinement_connection(
refiner, discr, group_factory(order)) refiner, discr, group_factory(order))
check_connection(connection) check_connection(connection)
...@@ -215,6 +223,19 @@ def test_refinement_connection( ...@@ -215,6 +223,19 @@ def test_refinement_connection(
f_interp = connection(queue, f_coarse).with_queue(queue) f_interp = connection(queue, f_coarse).with_queue(queue)
f_true = f(x_fine).with_queue(queue) f_true = f(x_fine).with_queue(queue)
if plot_mesh:
import matplotlib.pyplot as plt
x = x.get(queue)
err = np.array(np.log10(
1e-16 + np.abs((f_interp - f_true).get(queue))), dtype=float)
import matplotlib.cm as cm
cmap = cm.ScalarMappable(cmap=cm.jet)
cmap.set_array(err)
#norm = plt.matplotlib.colors.Normalize(vmin=min(err), vmax=max(err))
plt.scatter(x[0], x[1], c=cmap.to_rgba(err), s=20, cmap=cmap)
plt.colorbar(cmap)
plt.show()
import numpy.linalg as la import numpy.linalg as la
err = la.norm((f_interp - f_true).get(queue), np.inf) err = la.norm((f_interp - f_true).get(queue), np.inf)
eoc_rec.add_data_point(h, err) eoc_rec.add_data_point(h, err)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment