diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py index 299f9ef89816dab6362af39592c55a0e7a4a77f4..377f86c98b37d72de34cc8f8a1bbbc0f7ffa53b6 100644 --- a/meshmode/mesh/__init__.py +++ b/meshmode/mesh/__init__.py @@ -41,6 +41,8 @@ __doc__ = """ .. autoclass:: ElementConnectivity +.. autofunction:: as_python + """ @@ -487,4 +489,66 @@ def _compute_connectivity_from_vertices(mesh): # }}} + +# {{{ as_python + +def _numpy_array_as_python(array): + return "np.array(%s, dtype=np.%s)" % ( + repr(array.tolist()), + array.dtype.name) + + +def as_python(mesh, function_name="make_mesh"): + """Return a snippet of Python code (as a string) that will + recreate the mesh given as an input parameter. + """ + + from pytools.py_codegen import PythonCodeGenerator, Indentation + cg = PythonCodeGenerator() + cg("# generated by meshmode.mesh.as_python") + cg("") + cg("import numpy as np") + cg("from meshmode.mesh import Mesh, MeshElementGroup") + cg("") + + cg("def %s():" % function_name) + with Indentation(cg): + cg("vertices = " + _numpy_array_as_python(mesh.vertices)) + cg("") + cg("groups = []") + cg("") + for group in mesh.groups: + cg("import %s" % type(group).__module__) + cg("groups.append(%s.%s(" % ( + type(group).__module__, + type(group).__name__)) + cg(" order=%s," % group.order) + cg(" vertex_indices=%s," + % _numpy_array_as_python(group.vertex_indices)) + cg(" nodes=%s," + % _numpy_array_as_python(group.nodes)) + cg(" unit_nodes=%s))" + % _numpy_array_as_python(group.unit_nodes)) + + cg("return Mesh(vertices, groups, skip_tests=True,") + cg(" vertex_id_dtype=np.%s," % mesh.vertex_id_dtype.name) + cg(" element_id_dtype=np.%s," % mesh.element_id_dtype.name) + + if isinstance(mesh._element_connectivity, ElementConnectivity): + el_con_str = "(%s, %s)" % ( + _numpy_array_as_python( + mesh._element_connectivity.neighbors_starts), + _numpy_array_as_python( + mesh._element_connectivity.neighbors), + ) + else: + el_con_str = repr(mesh._element_connectivity) + + cg(" element_connectivity=%s)" % el_con_str) + + return cg.get() + +# }}} + + # vim: foldmethod=marker diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 6047113c956d1f616c57fc6610b8a75393b90db0..25b26dabe409cfd91612d52c2e1e8840c3f4ce42 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1,6 +1,4 @@ -from __future__ import division -from __future__ import absolute_import -from __future__ import print_function +from __future__ import division, absolute_import, print_function from six.moves import range __copyright__ = "Copyright (C) 2014 Andreas Kloeckner" @@ -356,6 +354,24 @@ def test_rect_mesh(do_plot=False): pt.show() +def test_as_python(): + from meshmode.mesh.generation import make_curve_mesh, cloverleaf + mesh = make_curve_mesh(cloverleaf, np.linspace(0, 1, 100), order=3) + + mesh.element_connectivity + + from meshmode.mesh import as_python + code = as_python(mesh) + + print(code) + exec_dict = {} + exec(compile(code, "gen_code.py", "exec"), exec_dict) + + mesh_2 = exec_dict["make_mesh"]() + + assert mesh == mesh_2 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: