diff --git a/meshmode/__init__.py b/meshmode/__init__.py index 1c38d87894841834007489ac2a4eea5175ba1428..9512b4a25aec6e3450e0a0bd2fc7ecc72452ca41 100644 --- a/meshmode/__init__.py +++ b/meshmode/__init__.py @@ -26,8 +26,11 @@ THE SOFTWARE. __doc__ = """ .. exception:: Error .. exception:: DataUnavailable +.. exception:: FileExistsError """ +import six + class Error(RuntimeError): pass @@ -35,3 +38,10 @@ class Error(RuntimeError): class DataUnavailable(Error): pass + + +if six.PY3: + from builtins import FileExistsError +else: + class FileExistsError(OSError): + pass diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index 704ef1bdc8c320a9b6055275f8fd83deae69a9f5..2352081e6555fc9dd2b63b69c81d870c7f9e103d 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -275,8 +275,10 @@ class Visualizer(object): # {{{ vtk - def write_vtk_file(self, file_name, names_and_fields, compressor=None, - real_only=False): + def write_vtk_file(self, file_name, names_and_fields, + compressor=None, + real_only=False, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, @@ -333,10 +335,13 @@ class Visualizer(object): grid.add_pointdata(DataArray(name, field, vector_format=VF_LIST_OF_COMPONENTS)) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + from meshmode import FileExistsError + if os.path.exists(file_name): + if overwrite: + os.remove(file_name) + else: + raise FileExistsError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) @@ -393,7 +398,9 @@ def draw_curve(discr): # {{{ adjacency -def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): +def write_nodal_adjacency_vtk_file(file_name, mesh, + compressor=None, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, AppendedDataXMLGenerator, @@ -430,10 +437,13 @@ def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): cell_types=np.asarray([VTK_LINE] * nconnections, dtype=np.uint8)) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + from meshmode import FileExistsError + if os.path.exists(file_name): + if overwrite: + os.remove(file_name) + else: + raise FileExistsError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py index a4d6facd061960c56fe3030fa801072afc461d17..53da2663ca76b5499456279a4cabb45cc963141c 100644 --- a/meshmode/mesh/visualization.py +++ b/meshmode/mesh/visualization.py @@ -171,7 +171,9 @@ def draw_curve(mesh, # {{{ write_vtk_file -def write_vertex_vtk_file(mesh, file_name, compressor=None): +def write_vertex_vtk_file(mesh, file_name, + compressor=None, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, AppendedDataXMLGenerator, @@ -223,10 +225,13 @@ def write_vertex_vtk_file(mesh, file_name, compressor=None): for vgrp in mesh.groups]), cell_types=cell_types) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + from meshmode import FileExistsError + if os.path.exists(file_name): + if overwrite: + os.remove(file_name) + else: + raise FileExistsError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 9e8791492d0a978cbd8f0acf0b928a8e6265dd48..e97d8efa036f378272b7f94588f1e867265d9a35 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1040,6 +1040,54 @@ def test_quad_multi_element(): plt.show() +def test_vtk_overwrite(ctx_getter): + pytest.importorskip("pyvisfile") + + def _try_write_vtk(writer, obj): + import os + from meshmode import FileExistsError + + filename = "test_vtk_overwrite.vtu" + if os.path.exists(filename): + os.remove(filename) + + writer(filename, []) + with pytest.raises(FileExistsError): + writer(filename, []) + + writer(filename, [], overwrite=True) + if os.path.exists(filename): + os.remove(filename) + + ctx = ctx_getter() + queue = cl.CommandQueue(ctx) + + target_order = 7 + + from meshmode.mesh.generation import generate_torus + mesh = generate_torus(10.0, 2.0, order=target_order) + + from meshmode.discretization import Discretization + from meshmode.discretization.poly_element import \ + InterpolatoryQuadratureSimplexGroupFactory + discr = Discretization( + queue.context, mesh, + InterpolatoryQuadratureSimplexGroupFactory(target_order)) + + from meshmode.discretization.visualization import make_visualizer + from meshmode.discretization.visualization import \ + write_nodal_adjacency_vtk_file + from meshmode.mesh.visualization import write_vertex_vtk_file + + vis = make_visualizer(queue, discr, 1) + _try_write_vtk(vis.write_vtk_file, discr) + + _try_write_vtk(lambda x, y, **kwargs: + write_vertex_vtk_file(discr.mesh, x, **kwargs), discr.mesh) + _try_write_vtk(lambda x, y, **kwargs: + write_nodal_adjacency_vtk_file(x, discr.mesh, **kwargs), discr.mesh) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: