diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index bf6e8a48139cc6f532d98053e705ba4724fcefaf..72abb0475facfa548f901cd0e77d607ddfa0a03d 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -338,7 +338,7 @@ class Visualizer(object): import os if os.path.exists(file_name): if overwrite: - os.path.remove(file_name) + os.remove(file_name) else: raise RuntimeError("output file '%s' already exists" % file_name) @@ -397,7 +397,9 @@ def draw_curve(discr): # {{{ adjacency -def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): +def write_nodal_adjacency_vtk_file(file_name, mesh, + compressor=None, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, AppendedDataXMLGenerator, @@ -434,10 +436,12 @@ def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): cell_types=np.asarray([VTK_LINE] * nconnections, dtype=np.uint8)) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + if os.path.exists(file_name): + if overwrite: + os.remove(file_name) + else: + raise RuntimeError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py index 675f9d900a85e506a997f2faf8d991ad3a086ef2..d9191c577e7448a192b0713fe440e6462a575015 100644 --- a/meshmode/mesh/visualization.py +++ b/meshmode/mesh/visualization.py @@ -228,7 +228,7 @@ def write_vertex_vtk_file(mesh, file_name, import os if os.path.exists(file_name): if overwrite: - os.path.remove(file_name) + os.remove(file_name) else: raise RuntimeError("output file '%s' already exists" % file_name) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 9e8791492d0a978cbd8f0acf0b928a8e6265dd48..31edf7479721fd2f248e862b5f799c9e4d353728 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1040,6 +1040,61 @@ def test_quad_multi_element(): plt.show() +def test_vtk_overwrite(ctx_getter): + def _try_write_vtk(writer, obj): + import os + filename = "test_vtk_overwrite.vtu" + if os.path.exists(filename): + os.remove(filename) + + writer(filename, []) + try: + writer(filename, []) + runtime_error = False + except RuntimeError: + print("file cannot be overwritten") + runtime_error = True + assert runtime_error + + try: + writer(filename, [], overwrite=True) + print("file overwritten") + runtime_error = False + except RuntimeError: + runtime_error = True + assert not runtime_error + + os.remove(filename) + + ctx = ctx_getter() + queue = cl.CommandQueue(ctx) + + target_order = 7 + + from meshmode.mesh.generation import generate_torus + mesh = generate_torus(10.0, 2.0, order=target_order) + + from meshmode.discretization import Discretization + from meshmode.discretization.poly_element import \ + InterpolatoryQuadratureSimplexGroupFactory + discr = Discretization( + queue.context, mesh, + InterpolatoryQuadratureSimplexGroupFactory(target_order)) + + from meshmode.discretization.visualization import make_visualizer + from meshmode.discretization.visualization import \ + write_nodal_adjacency_vtk_file + from meshmode.mesh.visualization import write_vertex_vtk_file + + vis = make_visualizer(queue, discr, 1) + _try_write_vtk(vis.write_vtk_file, discr) + + _try_write_vtk(lambda x, y, **kwargs: + write_vertex_vtk_file(discr.mesh, x, **kwargs), discr.mesh) + _try_write_vtk(lambda x, y, **kwargs: + write_nodal_adjacency_vtk_file(x, discr.mesh, **kwargs), discr.mesh) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: