Skip to content
Snippets Groups Projects
Commit 16e17c1e authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Make generate_box_mesh smart enough to use tensor product elements

parent 1d63c145
No related branches found
No related tags found
No related merge requests found
......@@ -463,12 +463,19 @@ def generate_torus(r_outer, r_inner, n_outer=20, n_inner=10, order=1):
# {{{ generate_box_mesh
def generate_box_mesh(axis_coords, order=1, coord_dtype=np.float64):
def generate_box_mesh(axis_coords, order=1, coord_dtype=np.float64,
group_factory=None):
"""Create a semi-structured mesh.
:param axis_coords: a tuple with a number of entries corresponding
to the number of dimensions, with each entry a numpy array
specifying the coordinates to be used along that axis.
:param group_factory: One of :class:`meshmode.mesh.SimplexElementGroup`
or :class:`meshmode.mesh.TensorProductElementGroup`.
.. versionchanged:: 2017.1
*group_factory* parameter added.
"""
for iaxis, axc in enumerate(axis_coords):
......@@ -492,6 +499,18 @@ def generate_box_mesh(axis_coords, order=1, coord_dtype=np.float64):
vertices = vertices.reshape(dim, -1)
from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup
if group_factory is None:
group_factory = SimplexElementGroup
if issubclass(group_factory, SimplexElementGroup):
is_tp = False
elif issubclass(group_factory, TensorProductElementGroup):
is_tp = True
else:
raise ValueError("unsupported value for 'group_factory': %s"
% group_factory)
el_vertices = []
if dim == 1:
......@@ -516,8 +535,11 @@ def generate_box_mesh(axis_coords, order=1, coord_dtype=np.float64):
c = vertex_indices[i, j+1]
d = vertex_indices[i+1, j+1]
el_vertices.append((a, b, c))
el_vertices.append((d, c, b))
if is_tp:
el_vertices.append((a, b, c, d))
else:
el_vertices.append((a, b, c))
el_vertices.append((d, c, b))
elif dim == 3:
for i in range(shape[0]-1):
......@@ -534,13 +556,19 @@ def generate_box_mesh(axis_coords, order=1, coord_dtype=np.float64):
a110 = vertex_indices[i+1, j+1, k]
a111 = vertex_indices[i+1, j+1, k+1]
el_vertices.append((a000, a100, a010, a001))
el_vertices.append((a101, a100, a001, a010))
el_vertices.append((a101, a011, a010, a001))
if is_tp:
el_vertices.append(
(a000, a001, a010, a011,
a100, a101, a110, a111))
else:
el_vertices.append((a000, a100, a010, a001))
el_vertices.append((a101, a100, a001, a010))
el_vertices.append((a101, a011, a010, a001))
el_vertices.append((a100, a010, a101, a110))
el_vertices.append((a011, a010, a110, a101))
el_vertices.append((a011, a111, a101, a110))
el_vertices.append((a100, a010, a101, a110))
el_vertices.append((a011, a010, a110, a101))
el_vertices.append((a011, a111, a101, a110))
else:
raise NotImplementedError("box meshes of dimension %d"
......@@ -549,7 +577,8 @@ def generate_box_mesh(axis_coords, order=1, coord_dtype=np.float64):
el_vertices = np.array(el_vertices, dtype=np.int32)
grp = make_group_from_vertices(
vertices.reshape(dim, -1), el_vertices, order)
vertices.reshape(dim, -1), el_vertices, order,
group_factory=group_factory)
from meshmode.mesh import Mesh
return Mesh(vertices, [grp],
......
......@@ -857,6 +857,50 @@ def no_test_quad_mesh_3d():
# }}}
def test_quad_single_element():
from meshmode.mesh.generation import make_group_from_vertices
from meshmode.mesh import Mesh, TensorProductElementGroup
vertices = np.array([
[0.91, 1.10],
[2.64, 1.27],
[0.97, 2.56],
[3.00, 3.41],
]).T
mg = make_group_from_vertices(
vertices,
np.array([[0, 1, 2, 3]], dtype=np.int32),
30, group_factory=TensorProductElementGroup)
Mesh(vertices, [mg], nodal_adjacency=None, facial_adjacency_groups=None)
if 0:
import matplotlib.pyplot as plt
plt.plot(
mg.nodes[0].reshape(-1),
mg.nodes[1].reshape(-1), "o")
plt.show()
def test_quad_multi_element():
from meshmode.mesh.generation import generate_box_mesh
from meshmode.mesh import TensorProductElementGroup
mesh = generate_box_mesh(
(
np.linspace(3, 8, 4),
np.linspace(3, 8, 4),
np.linspace(3, 8, 4),
),
10, group_factory=TensorProductElementGroup)
if 0:
import matplotlib.pyplot as plt
mg = mesh.groups[0]
plt.plot(
mg.nodes[0].reshape(-1),
mg.nodes[1].reshape(-1), "o")
plt.show()
if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment