diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index 032b34189ed0761c07dc47b2c57622ab01e448c7..19f4f44c9fa4e9ddb88f3e0f5701cee4da071910 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -93,7 +93,7 @@ class _VisConnectivityGroup(Record): """ .. attribute:: vis_connectivity - an array of shape ``(group.nelements,nsubelements,primitive_element_size)`` + an array of shape ``(group.nelements, nsubelements, primitive_element_size)`` .. attribute:: vtk_cell_type @@ -121,16 +121,15 @@ class _VisConnectivityGroup(Record): # {{{ vtk visualizers -class VTKVisualizer(object): - """ - .. automethod:: write_vtk_file - """ +class VTKConnectivity: + """Connectivity for standard linear VTK element types. - def __init__(self, connection, element_shrink_factor=None): - if element_shrink_factor is None: - element_shrink_factor = 1.0 + .. attribute:: version + .. attribute:: cells + .. attribute:: groups + """ - self.element_shrink_factor = element_shrink_factor + def __init__(self, connection): self.connection = connection self.discr = connection.from_discr self.vis_discr = connection.to_discr @@ -139,8 +138,6 @@ class VTKVisualizer(object): def version(self): return "0.1" - # {{{ connectivity - @property def simplex_cell_types(self): import pyvisfile.vtk as vtk @@ -213,17 +210,9 @@ class VTKVisualizer(object): assert len(node_tuples) == grp.nunit_dofs return el_connectivity, vtk_cell_type - @memoize_method - def vis_nodes_numpy(self): - actx = self.vis_discr._setup_actx - return np.array([ - actx.to_numpy(flatten(thaw(actx, ary))) - for ary in self.vis_discr.nodes() - ]) - @property @memoize_method - def vtk_cells(self): + def cells(self): return np.hstack([ vgrp.vis_connectivity.reshape(-1) for vgrp in self.groups ]) @@ -261,94 +250,10 @@ class VTKVisualizer(object): return result - # }}} - - def write_vtk_file(self, file_name, names_and_fields, - compressor=None, real_only=False, overwrite=False): - from pyvisfile.vtk import ( - UnstructuredGrid, DataArray, - AppendedDataXMLGenerator, - VF_LIST_OF_COMPONENTS) - nodes = self.vis_nodes_numpy() - names_and_fields = [ - (name, resample_to_numpy(self.connection, fld)) - for name, fld in names_and_fields] +class VTKLagrangeConnectivity(VTKConnectivity): + """Connectivity for high-order Lagrange elements.""" - # {{{ create cell_types - - nsubelements = sum(vgrp.nsubelements for vgrp in self.groups) - cell_types = np.empty(nsubelements, dtype=np.uint8) - cell_types.fill(255) - - for vgrp in self.groups: - isubelements = np.s_[ - vgrp.subelement_nr_base: - vgrp.subelement_nr_base + vgrp.nsubelements] - cell_types[isubelements] = vgrp.vtk_cell_type - - assert (cell_types < 255).all() - - # }}} - - # {{{ shrink elements - - if abs(self.element_shrink_factor - 1.0) > 1.0e-14: - node_nr_base = 0 - for vgrp in self.vis_discr.groups: - nodes_view = ( - nodes[:, node_nr_base:node_nr_base + vgrp.ndofs] - .reshape(nodes.shape[0], vgrp.nelements, vgrp.nunit_dofs)) - - el_centers = np.mean(nodes_view, axis=-1) - nodes_view[:] = ( - (self.element_shrink_factor * nodes_view) - + (1-self.element_shrink_factor) - * el_centers[:, :, np.newaxis]) - - node_nr_base += vgrp.ndofs - - # }}} - - # {{{ create grid - - nodes = nodes.reshape(self.vis_discr.ambient_dim, -1) - points = DataArray("points", nodes, - vector_format=VF_LIST_OF_COMPONENTS) - - grid = UnstructuredGrid( - (nodes.shape[1], points), - cells=self.vtk_cells, - cell_types=cell_types) - - for name, field in separate_by_real_and_imag(names_and_fields, real_only): - grid.add_pointdata( - DataArray(name, field, vector_format=VF_LIST_OF_COMPONENTS) - ) - - # }}} - - # {{{ write - - import os - from meshmode import FileExistsError - if os.path.exists(file_name): - if overwrite: - os.remove(file_name) - else: - raise FileExistsError("output file '%s' already exists" % file_name) - - with open(file_name, "w") as outf: - generator = AppendedDataXMLGenerator( - compressor=compressor, - vtk_file_version=self.version) - - generator(grid).write(outf) - - # }}} - - -class VTKLagrangeVisualizer(VTKVisualizer): @property def version(self): # NOTE: version 2.2 has an updated ordering for the hexahedron @@ -382,7 +287,8 @@ class VTKLagrangeVisualizer(VTKVisualizer): vtk_lagrange_simplex_node_tuples, vtk_lagrange_simplex_node_tuples_to_permutation) - node_tuples = vtk_lagrange_simplex_node_tuples(grp.dim, grp.order) + node_tuples = vtk_lagrange_simplex_node_tuples( + grp.dim, grp.order, is_consistent=True) el_connectivity = np.array( vtk_lagrange_simplex_node_tuples_to_permutation(node_tuples), dtype=np.intp).reshape(1, 1, -1) @@ -398,7 +304,7 @@ class VTKLagrangeVisualizer(VTKVisualizer): @property @memoize_method - def vtk_cells(self): + def cells(self): connectivity = np.hstack([ grp.vis_connectivity.reshape(-1) for grp in self.groups @@ -436,18 +342,22 @@ class Visualizer(object): """ def __init__(self, connection, - element_shrink_factor=None, - use_high_order_vtk=False): + element_shrink_factor=None): self.connection = connection self.discr = connection.from_discr self.vis_discr = connection.to_discr - if use_high_order_vtk: - self.vtk = VTKLagrangeVisualizer(connection, - element_shrink_factor=element_shrink_factor) - else: - self.vtk = VTKVisualizer(connection, - element_shrink_factor=element_shrink_factor) + if element_shrink_factor is None: + element_shrink_factor = 1.0 + self.element_shrink_factor = element_shrink_factor + + @memoize_method + def _vis_nodes_numpy(self): + actx = self.vis_discr._setup_actx + return np.array([ + actx.to_numpy(flatten(thaw(actx, ary))) + for ary in self.vis_discr.nodes() + ]) # {{{ mayavi @@ -456,13 +366,11 @@ class Visualizer(object): do_show = kwargs.pop("do_show", True) - nodes = self.vtk.vis_nodes_numpy() + nodes = self._vis_nodes_numpy() field = resample_to_numpy(self.connection, field) assert nodes.shape[0] == self.vis_discr.ambient_dim - #mlab.points3d(nodes[0], nodes[1], 0*nodes[0]) - - vis_connectivity, = self._vis_connectivity() + vis_connectivity = self._vtk_connectivity.groups[0].vis_connectivity if self.vis_discr.dim == 1: nodes = list(nodes) @@ -502,14 +410,109 @@ class Visualizer(object): # {{{ vtk + @property + @memoize_method + def _vtk_connectivity(self): + return VTKConnectivity(self.connection) + + @property + @memoize_method + def _vtk_lagrange_connectivity(self): + return VTKLagrangeConnectivity(self.connection) + def write_vtk_file(self, file_name, names_and_fields, compressor=None, real_only=False, - overwrite=False): - self.vtk.write_vtk_file(file_name, names_and_fields, - compressor=compressor, - real_only=real_only, - overwrite=overwrite) + overwrite=False, + use_lagrange_elements=False): + from pyvisfile.vtk import ( + UnstructuredGrid, DataArray, + AppendedDataXMLGenerator, + VF_LIST_OF_COMPONENTS) + + if use_lagrange_elements: + connectivity = self._vtk_lagrange_connectivity + else: + connectivity = self._vtk_connectivity + + nodes = self._vis_nodes_numpy() + names_and_fields = [ + (name, resample_to_numpy(self.connection, fld)) + for name, fld in names_and_fields] + + # {{{ create cell_types + + nsubelements = sum(vgrp.nsubelements for vgrp in connectivity.groups) + cell_types = np.empty(nsubelements, dtype=np.uint8) + cell_types.fill(255) + + for vgrp in connectivity.groups: + isubelements = np.s_[ + vgrp.subelement_nr_base: + vgrp.subelement_nr_base + vgrp.nsubelements] + cell_types[isubelements] = vgrp.vtk_cell_type + + assert (cell_types < 255).all() + + # }}} + + # {{{ shrink elements + + if abs(self.element_shrink_factor - 1.0) > 1.0e-14: + node_nr_base = 0 + for vgrp in self.vis_discr.groups: + nodes_view = ( + nodes[:, node_nr_base:node_nr_base + vgrp.ndofs] + .reshape(nodes.shape[0], vgrp.nelements, vgrp.nunit_dofs)) + + el_centers = np.mean(nodes_view, axis=-1) + nodes_view[:] = ( + (self.element_shrink_factor * nodes_view) + + (1-self.element_shrink_factor) + * el_centers[:, :, np.newaxis]) + + node_nr_base += vgrp.ndofs + + # }}} + + # {{{ create grid + + nodes = nodes.reshape(self.vis_discr.ambient_dim, -1) + points = DataArray("points", nodes, + vector_format=VF_LIST_OF_COMPONENTS) + + grid = UnstructuredGrid( + (nodes.shape[1], points), + cells=connectivity.cells, + cell_types=cell_types) + + for name, field in separate_by_real_and_imag(names_and_fields, real_only): + grid.add_pointdata( + DataArray(name, field, vector_format=VF_LIST_OF_COMPONENTS) + ) + + # }}} + + # {{{ write + + import os + from meshmode import FileExistsError + if os.path.exists(file_name): + if overwrite: + os.remove(file_name) + else: + raise FileExistsError("output file '%s' already exists" % file_name) + + with open(file_name, "w") as outf: + generator = AppendedDataXMLGenerator( + compressor=compressor, + vtk_file_version=connectivity.version) + + generator(grid).write(outf) + + # }}} + + # }}} # {{{ matplotlib 3D @@ -524,12 +527,12 @@ class Visualizer(object): vmax = kwargs.pop("vmax", None) norm = kwargs.pop("norm", None) - nodes = self.vtk.vis_nodes_numpy() - field = resample_to_numpy(field) + nodes = self._vis_nodes_numpy() + field = resample_to_numpy(self.connection, field) assert nodes.shape[0] == self.vis_discr.ambient_dim - vis_connectivity, = self._vis_connectivity() + vis_connectivity, = self._vtk_connectivity.groups fig = plt.gcf() ax = fig.gca(projection="3d") @@ -580,18 +583,12 @@ class Visualizer(object): # }}} -def make_visualizer(actx, discr, vis_order, - element_shrink_factor=None, use_high_order_vtk=False): +def make_visualizer(actx, discr, vis_order, element_shrink_factor=None): from meshmode.discretization import Discretization - if use_high_order_vtk: - from meshmode.discretization.poly_element import ( - PolynomialEquidistantSimplexElementGroup as SimplexElementGroup, - EquidistantTensorProductElementGroup as TensorElementGroup) - else: - from meshmode.discretization.poly_element import ( - PolynomialWarpAndBlendElementGroup as SimplexElementGroup, - LegendreGaussLobattoTensorProductElementGroup as TensorElementGroup) + from meshmode.discretization.poly_element import ( + PolynomialEquidistantSimplexElementGroup as SimplexElementGroup, + EquidistantTensorProductElementGroup as TensorElementGroup) from meshmode.discretization.poly_element import OrderAndTypeBasedGroupFactory vis_discr = Discretization( @@ -607,8 +604,7 @@ def make_visualizer(actx, discr, vis_order, return Visualizer( make_same_mesh_connection(actx, vis_discr, discr), - element_shrink_factor=element_shrink_factor, - use_high_order_vtk=use_high_order_vtk) + element_shrink_factor=element_shrink_factor) # }}} diff --git a/test/test_meshmode.py b/test/test_meshmode.py index bfb6b53961e09c6aca98659b64f8a89dd0741baa..cabbd593904bc82f5abd44cc4079279ccdf99c4c 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -84,6 +84,61 @@ def test_circle_mesh(visualize=False): # }}} +# {{{ test visualizer + +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_visualizers(ctx_factory, dim): + logging.basicConfig(level=logging.INFO) + + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + actx = PyOpenCLArrayContext(queue) + + nelements = 64 + target_order = 4 + + if dim == 1: + mesh = mgen.make_curve_mesh( + mgen.NArmedStarfish(5, 0.25), + np.linspace(0.0, 1.0, nelements + 1), + target_order) + elif dim == 2: + mesh = mgen.generate_torus(5.0, 1.0, order=target_order) + elif dim == 3: + mesh = mgen.generate_warped_rect_mesh(dim, target_order, 5) + else: + raise ValueError("unknown dimensionality") + + from meshmode.discretization import Discretization + discr = Discretization(actx, mesh, + InterpolatoryQuadratureSimplexGroupFactory(target_order)) + + from meshmode.discretization.visualization import make_visualizer + vis = make_visualizer(actx, discr, target_order) + + vis.write_vtk_file(f"visualizer_vtk_lagrange_{dim}.vtu", [], + use_lagrange_elements=True, overwrite=True) + vis.write_vtk_file(f"visualizer_vtk_linear_{dim}.vtu", [], + use_lagrange_elements=False, overwrite=True) + + if mesh.dim <= 2: + field = thaw(actx, discr.nodes()[0]) + + if mesh.dim == 2: + try: + vis.show_scalar_in_matplotlib_3d(field, do_show=False) + except ImportError: + logger.info("matplotlib not available") + + if mesh.dim <= 2: + try: + vis.show_scalar_in_mayavi(field, do_show=False) + except ImportError: + logger.info("mayavi not avaiable") + +# }}} + + # {{{ test boundary tags def test_boundary_tags():