From b3abd7df14a2c860bf4efdfb3f5134b94c715a6b Mon Sep 17 00:00:00 2001 From: Alex Fikl Date: Thu, 21 Jun 2018 16:58:12 -0500 Subject: [PATCH] add a test to check that overwriting works --- meshmode/discretization/visualization.py | 16 ++++--- meshmode/mesh/visualization.py | 2 +- test/test_meshmode.py | 55 ++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index bf6e8a48..72abb047 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -338,7 +338,7 @@ class Visualizer(object): import os if os.path.exists(file_name): if overwrite: - os.path.remove(file_name) + os.remove(file_name) else: raise RuntimeError("output file '%s' already exists" % file_name) @@ -397,7 +397,9 @@ def draw_curve(discr): # {{{ adjacency -def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): +def write_nodal_adjacency_vtk_file(file_name, mesh, + compressor=None, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, AppendedDataXMLGenerator, @@ -434,10 +436,12 @@ def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): cell_types=np.asarray([VTK_LINE] * nconnections, dtype=np.uint8)) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + if os.path.exists(file_name): + if overwrite: + os.remove(file_name) + else: + raise RuntimeError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py index 675f9d90..d9191c57 100644 --- a/meshmode/mesh/visualization.py +++ b/meshmode/mesh/visualization.py @@ -228,7 +228,7 @@ def write_vertex_vtk_file(mesh, file_name, import os if os.path.exists(file_name): if overwrite: - os.path.remove(file_name) + os.remove(file_name) else: raise RuntimeError("output file '%s' already exists" % file_name) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 9e879149..31edf747 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1040,6 +1040,61 @@ def test_quad_multi_element(): plt.show() +def test_vtk_overwrite(ctx_getter): + def _try_write_vtk(writer, obj): + import os + filename = "test_vtk_overwrite.vtu" + if os.path.exists(filename): + os.remove(filename) + + writer(filename, []) + try: + writer(filename, []) + runtime_error = False + except RuntimeError: + print("file cannot be overwritten") + runtime_error = True + assert runtime_error + + try: + writer(filename, [], overwrite=True) + print("file overwritten") + runtime_error = False + except RuntimeError: + runtime_error = True + assert not runtime_error + + os.remove(filename) + + ctx = ctx_getter() + queue = cl.CommandQueue(ctx) + + target_order = 7 + + from meshmode.mesh.generation import generate_torus + mesh = generate_torus(10.0, 2.0, order=target_order) + + from meshmode.discretization import Discretization + from meshmode.discretization.poly_element import \ + InterpolatoryQuadratureSimplexGroupFactory + discr = Discretization( + queue.context, mesh, + InterpolatoryQuadratureSimplexGroupFactory(target_order)) + + from meshmode.discretization.visualization import make_visualizer + from meshmode.discretization.visualization import \ + write_nodal_adjacency_vtk_file + from meshmode.mesh.visualization import write_vertex_vtk_file + + vis = make_visualizer(queue, discr, 1) + _try_write_vtk(vis.write_vtk_file, discr) + + _try_write_vtk(lambda x, y, **kwargs: + write_vertex_vtk_file(discr.mesh, x, **kwargs), discr.mesh) + _try_write_vtk(lambda x, y, **kwargs: + write_nodal_adjacency_vtk_file(x, discr.mesh, **kwargs), discr.mesh) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab