From 5af6016bd7f13db1438ec6630f9f81f6d7740b0c Mon Sep 17 00:00:00 2001 From: Alex Fikl Date: Thu, 21 Jun 2018 10:49:53 -0500 Subject: [PATCH 1/4] visualization: add flag to enable overwriting vtk files --- meshmode/discretization/visualization.py | 16 ++++++++++------ meshmode/mesh/visualization.py | 14 +++++++++----- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index 704ef1bd..bf6e8a48 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -275,8 +275,10 @@ class Visualizer(object): # {{{ vtk - def write_vtk_file(self, file_name, names_and_fields, compressor=None, - real_only=False): + def write_vtk_file(self, file_name, names_and_fields, + compressor=None, + real_only=False, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, @@ -333,10 +335,12 @@ class Visualizer(object): grid.add_pointdata(DataArray(name, field, vector_format=VF_LIST_OF_COMPONENTS)) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + if os.path.exists(file_name): + if overwrite: + os.path.remove(file_name) + else: + raise RuntimeError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py index a4d6facd..675f9d90 100644 --- a/meshmode/mesh/visualization.py +++ b/meshmode/mesh/visualization.py @@ -171,7 +171,9 @@ def draw_curve(mesh, # {{{ write_vtk_file -def write_vertex_vtk_file(mesh, file_name, compressor=None): +def write_vertex_vtk_file(mesh, file_name, + compressor=None, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, AppendedDataXMLGenerator, @@ -223,10 +225,12 @@ def write_vertex_vtk_file(mesh, file_name, compressor=None): for vgrp in mesh.groups]), cell_types=cell_types) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + if os.path.exists(file_name): + if overwrite: + os.path.remove(file_name) + else: + raise RuntimeError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) -- GitLab From b3abd7df14a2c860bf4efdfb3f5134b94c715a6b Mon Sep 17 00:00:00 2001 From: Alex Fikl Date: Thu, 21 Jun 2018 16:58:12 -0500 Subject: [PATCH 2/4] add a test to check that overwriting works --- meshmode/discretization/visualization.py | 16 ++++--- meshmode/mesh/visualization.py | 2 +- test/test_meshmode.py | 55 ++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index bf6e8a48..72abb047 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -338,7 +338,7 @@ class Visualizer(object): import os if os.path.exists(file_name): if overwrite: - os.path.remove(file_name) + os.remove(file_name) else: raise RuntimeError("output file '%s' already exists" % file_name) @@ -397,7 +397,9 @@ def draw_curve(discr): # {{{ adjacency -def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): +def write_nodal_adjacency_vtk_file(file_name, mesh, + compressor=None, + overwrite=False): from pyvisfile.vtk import ( UnstructuredGrid, DataArray, AppendedDataXMLGenerator, @@ -434,10 +436,12 @@ def write_nodal_adjacency_vtk_file(file_name, mesh, compressor=None,): cell_types=np.asarray([VTK_LINE] * nconnections, dtype=np.uint8)) - from os.path import exists - if exists(file_name): - raise RuntimeError("output file '%s' already exists" - % file_name) + import os + if os.path.exists(file_name): + if overwrite: + os.remove(file_name) + else: + raise RuntimeError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py index 675f9d90..d9191c57 100644 --- a/meshmode/mesh/visualization.py +++ b/meshmode/mesh/visualization.py @@ -228,7 +228,7 @@ def write_vertex_vtk_file(mesh, file_name, import os if os.path.exists(file_name): if overwrite: - os.path.remove(file_name) + os.remove(file_name) else: raise RuntimeError("output file '%s' already exists" % file_name) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 9e879149..31edf747 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1040,6 +1040,61 @@ def test_quad_multi_element(): plt.show() +def test_vtk_overwrite(ctx_getter): + def _try_write_vtk(writer, obj): + import os + filename = "test_vtk_overwrite.vtu" + if os.path.exists(filename): + os.remove(filename) + + writer(filename, []) + try: + writer(filename, []) + runtime_error = False + except RuntimeError: + print("file cannot be overwritten") + runtime_error = True + assert runtime_error + + try: + writer(filename, [], overwrite=True) + print("file overwritten") + runtime_error = False + except RuntimeError: + runtime_error = True + assert not runtime_error + + os.remove(filename) + + ctx = ctx_getter() + queue = cl.CommandQueue(ctx) + + target_order = 7 + + from meshmode.mesh.generation import generate_torus + mesh = generate_torus(10.0, 2.0, order=target_order) + + from meshmode.discretization import Discretization + from meshmode.discretization.poly_element import \ + InterpolatoryQuadratureSimplexGroupFactory + discr = Discretization( + queue.context, mesh, + InterpolatoryQuadratureSimplexGroupFactory(target_order)) + + from meshmode.discretization.visualization import make_visualizer + from meshmode.discretization.visualization import \ + write_nodal_adjacency_vtk_file + from meshmode.mesh.visualization import write_vertex_vtk_file + + vis = make_visualizer(queue, discr, 1) + _try_write_vtk(vis.write_vtk_file, discr) + + _try_write_vtk(lambda x, y, **kwargs: + write_vertex_vtk_file(discr.mesh, x, **kwargs), discr.mesh) + _try_write_vtk(lambda x, y, **kwargs: + write_nodal_adjacency_vtk_file(x, discr.mesh, **kwargs), discr.mesh) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab From 0b9503c05b2b6c4fa6339b95fd627b3436674189 Mon Sep 17 00:00:00 2001 From: Alex Fikl Date: Thu, 21 Jun 2018 17:36:51 -0500 Subject: [PATCH 3/4] check for pyvisfile --- test/test_meshmode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 31edf747..05bbe5ad 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1041,6 +1041,8 @@ def test_quad_multi_element(): def test_vtk_overwrite(ctx_getter): + pytest.importorskip("pyvisfile") + def _try_write_vtk(writer, obj): import os filename = "test_vtk_overwrite.vtu" -- GitLab From 4f6a0c6560a7c8e47456249a29fc2a76c4966369 Mon Sep 17 00:00:00 2001 From: Alex Fikl Date: Thu, 21 Jun 2018 20:39:23 -0500 Subject: [PATCH 4/4] use pytest.raises with a custom expression --- meshmode/__init__.py | 10 ++++++++++ meshmode/discretization/visualization.py | 6 ++++-- meshmode/mesh/visualization.py | 3 ++- test/test_meshmode.py | 23 +++++++---------------- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/meshmode/__init__.py b/meshmode/__init__.py index 1c38d878..9512b4a2 100644 --- a/meshmode/__init__.py +++ b/meshmode/__init__.py @@ -26,8 +26,11 @@ THE SOFTWARE. __doc__ = """ .. exception:: Error .. exception:: DataUnavailable +.. exception:: FileExistsError """ +import six + class Error(RuntimeError): pass @@ -35,3 +38,10 @@ class Error(RuntimeError): class DataUnavailable(Error): pass + + +if six.PY3: + from builtins import FileExistsError +else: + class FileExistsError(OSError): + pass diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index 72abb047..2352081e 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -336,11 +336,12 @@ class Visualizer(object): vector_format=VF_LIST_OF_COMPONENTS)) import os + from meshmode import FileExistsError if os.path.exists(file_name): if overwrite: os.remove(file_name) else: - raise RuntimeError("output file '%s' already exists" % file_name) + raise FileExistsError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) @@ -437,11 +438,12 @@ def write_nodal_adjacency_vtk_file(file_name, mesh, dtype=np.uint8)) import os + from meshmode import FileExistsError if os.path.exists(file_name): if overwrite: os.remove(file_name) else: - raise RuntimeError("output file '%s' already exists" % file_name) + raise FileExistsError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/meshmode/mesh/visualization.py b/meshmode/mesh/visualization.py index d9191c57..53da2663 100644 --- a/meshmode/mesh/visualization.py +++ b/meshmode/mesh/visualization.py @@ -226,11 +226,12 @@ def write_vertex_vtk_file(mesh, file_name, cell_types=cell_types) import os + from meshmode import FileExistsError if os.path.exists(file_name): if overwrite: os.remove(file_name) else: - raise RuntimeError("output file '%s' already exists" % file_name) + raise FileExistsError("output file '%s' already exists" % file_name) with open(file_name, "w") as outf: AppendedDataXMLGenerator(compressor)(grid).write(outf) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index 05bbe5ad..e97d8efa 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -1045,28 +1045,19 @@ def test_vtk_overwrite(ctx_getter): def _try_write_vtk(writer, obj): import os + from meshmode import FileExistsError + filename = "test_vtk_overwrite.vtu" if os.path.exists(filename): os.remove(filename) writer(filename, []) - try: + with pytest.raises(FileExistsError): writer(filename, []) - runtime_error = False - except RuntimeError: - print("file cannot be overwritten") - runtime_error = True - assert runtime_error - - try: - writer(filename, [], overwrite=True) - print("file overwritten") - runtime_error = False - except RuntimeError: - runtime_error = True - assert not runtime_error - - os.remove(filename) + + writer(filename, [], overwrite=True) + if os.path.exists(filename): + os.remove(filename) ctx = ctx_getter() queue = cl.CommandQueue(ctx) -- GitLab