diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py index 52711a0a8b3654a0ea89291c47a7c14ab19d8ed0..a1b1470739b0f06e6f6f3c65f20b684f4006c8e9 100644 --- a/meshmode/mesh/visualization.py +++ b/meshmode/mesh/visualization.py @@ -142,14 +142,23 @@ def draw_2d_mesh(mesh, draw_vertex_numbers=True, draw_element_numbers=True, # {{{ draw_curve -def draw_curve(mesh): - import matplotlib.pyplot as pt - pt.plot(mesh.vertices[0], mesh.vertices[1], "o") +def draw_curve(mesh, + el_bdry_style="o", el_bdry_kwargs=None, + node_style="x-", node_kwargs=None): + import matplotlib.pyplot as plt + + if el_bdry_kwargs is None: + el_bdry_kwargs = {} + if node_kwargs is None: + node_kwargs = {} + + plt.plot(mesh.vertices[0], mesh.vertices[1], el_bdry_style, **el_bdry_kwargs) for i, group in enumerate(mesh.groups): - pt.plot( + plt.plot( group.nodes[0].ravel(), - group.nodes[1].ravel(), "-x", label="Group %d" % i) + group.nodes[1].ravel(), node_style, label="Group %d" % i, + **node_kwargs) # }}}