diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py
index 032b34189ed0761c07dc47b2c57622ab01e448c7..19f4f44c9fa4e9ddb88f3e0f5701cee4da071910 100644
--- a/meshmode/discretization/visualization.py
+++ b/meshmode/discretization/visualization.py
@@ -93,7 +93,7 @@ class _VisConnectivityGroup(Record):
     """
     .. attribute:: vis_connectivity
 
-        an array of shape ``(group.nelements,nsubelements,primitive_element_size)``
+        an array of shape ``(group.nelements, nsubelements, primitive_element_size)``
 
     .. attribute:: vtk_cell_type
 
@@ -121,16 +121,15 @@ class _VisConnectivityGroup(Record):
 
 # {{{ vtk visualizers
 
-class VTKVisualizer(object):
-    """
-    .. automethod:: write_vtk_file
-    """
+class VTKConnectivity:
+    """Connectivity for standard linear VTK element types.
 
-    def __init__(self, connection, element_shrink_factor=None):
-        if element_shrink_factor is None:
-            element_shrink_factor = 1.0
+    .. attribute:: version
+    .. attribute:: cells
+    .. attribute:: groups
+    """
 
-        self.element_shrink_factor = element_shrink_factor
+    def __init__(self, connection):
         self.connection = connection
         self.discr = connection.from_discr
         self.vis_discr = connection.to_discr
@@ -139,8 +138,6 @@ class VTKVisualizer(object):
     def version(self):
         return "0.1"
 
-    # {{{ connectivity
-
     @property
     def simplex_cell_types(self):
         import pyvisfile.vtk as vtk
@@ -213,17 +210,9 @@ class VTKVisualizer(object):
         assert len(node_tuples) == grp.nunit_dofs
         return el_connectivity, vtk_cell_type
 
-    @memoize_method
-    def vis_nodes_numpy(self):
-        actx = self.vis_discr._setup_actx
-        return np.array([
-            actx.to_numpy(flatten(thaw(actx, ary)))
-            for ary in self.vis_discr.nodes()
-            ])
-
     @property
     @memoize_method
-    def vtk_cells(self):
+    def cells(self):
         return np.hstack([
             vgrp.vis_connectivity.reshape(-1) for vgrp in self.groups
             ])
@@ -261,94 +250,10 @@ class VTKVisualizer(object):
 
         return result
 
-    # }}}
-
-    def write_vtk_file(self, file_name, names_and_fields,
-            compressor=None, real_only=False, overwrite=False):
-        from pyvisfile.vtk import (
-                UnstructuredGrid, DataArray,
-                AppendedDataXMLGenerator,
-                VF_LIST_OF_COMPONENTS)
 
-        nodes = self.vis_nodes_numpy()
-        names_and_fields = [
-                (name, resample_to_numpy(self.connection, fld))
-                for name, fld in names_and_fields]
+class VTKLagrangeConnectivity(VTKConnectivity):
+    """Connectivity for high-order Lagrange elements."""
 
-        # {{{ create cell_types
-
-        nsubelements = sum(vgrp.nsubelements for vgrp in self.groups)
-        cell_types = np.empty(nsubelements, dtype=np.uint8)
-        cell_types.fill(255)
-
-        for vgrp in self.groups:
-            isubelements = np.s_[
-                    vgrp.subelement_nr_base:
-                    vgrp.subelement_nr_base + vgrp.nsubelements]
-            cell_types[isubelements] = vgrp.vtk_cell_type
-
-        assert (cell_types < 255).all()
-
-        # }}}
-
-        # {{{ shrink elements
-
-        if abs(self.element_shrink_factor - 1.0) > 1.0e-14:
-            node_nr_base = 0
-            for vgrp in self.vis_discr.groups:
-                nodes_view = (
-                        nodes[:, node_nr_base:node_nr_base + vgrp.ndofs]
-                        .reshape(nodes.shape[0], vgrp.nelements, vgrp.nunit_dofs))
-
-                el_centers = np.mean(nodes_view, axis=-1)
-                nodes_view[:] = (
-                        (self.element_shrink_factor * nodes_view)
-                        + (1-self.element_shrink_factor)
-                        * el_centers[:, :, np.newaxis])
-
-                node_nr_base += vgrp.ndofs
-
-        # }}}
-
-        # {{{ create grid
-
-        nodes = nodes.reshape(self.vis_discr.ambient_dim, -1)
-        points = DataArray("points", nodes,
-                vector_format=VF_LIST_OF_COMPONENTS)
-
-        grid = UnstructuredGrid(
-                (nodes.shape[1], points),
-                cells=self.vtk_cells,
-                cell_types=cell_types)
-
-        for name, field in separate_by_real_and_imag(names_and_fields, real_only):
-            grid.add_pointdata(
-                    DataArray(name, field, vector_format=VF_LIST_OF_COMPONENTS)
-                    )
-
-        # }}}
-
-        # {{{ write
-
-        import os
-        from meshmode import FileExistsError
-        if os.path.exists(file_name):
-            if overwrite:
-                os.remove(file_name)
-            else:
-                raise FileExistsError("output file '%s' already exists" % file_name)
-
-        with open(file_name, "w") as outf:
-            generator = AppendedDataXMLGenerator(
-                    compressor=compressor,
-                    vtk_file_version=self.version)
-
-            generator(grid).write(outf)
-
-        # }}}
-
-
-class VTKLagrangeVisualizer(VTKVisualizer):
     @property
     def version(self):
         # NOTE: version 2.2 has an updated ordering for the hexahedron
@@ -382,7 +287,8 @@ class VTKLagrangeVisualizer(VTKVisualizer):
                     vtk_lagrange_simplex_node_tuples,
                     vtk_lagrange_simplex_node_tuples_to_permutation)
 
-            node_tuples = vtk_lagrange_simplex_node_tuples(grp.dim, grp.order)
+            node_tuples = vtk_lagrange_simplex_node_tuples(
+                    grp.dim, grp.order, is_consistent=True)
             el_connectivity = np.array(
                     vtk_lagrange_simplex_node_tuples_to_permutation(node_tuples),
                     dtype=np.intp).reshape(1, 1, -1)
@@ -398,7 +304,7 @@ class VTKLagrangeVisualizer(VTKVisualizer):
 
     @property
     @memoize_method
-    def vtk_cells(self):
+    def cells(self):
         connectivity = np.hstack([
             grp.vis_connectivity.reshape(-1)
             for grp in self.groups
@@ -436,18 +342,22 @@ class Visualizer(object):
     """
 
     def __init__(self, connection,
-            element_shrink_factor=None,
-            use_high_order_vtk=False):
+            element_shrink_factor=None):
         self.connection = connection
         self.discr = connection.from_discr
         self.vis_discr = connection.to_discr
 
-        if use_high_order_vtk:
-            self.vtk = VTKLagrangeVisualizer(connection,
-                    element_shrink_factor=element_shrink_factor)
-        else:
-            self.vtk = VTKVisualizer(connection,
-                    element_shrink_factor=element_shrink_factor)
+        if element_shrink_factor is None:
+            element_shrink_factor = 1.0
+        self.element_shrink_factor = element_shrink_factor
+
+    @memoize_method
+    def _vis_nodes_numpy(self):
+        actx = self.vis_discr._setup_actx
+        return np.array([
+            actx.to_numpy(flatten(thaw(actx, ary)))
+            for ary in self.vis_discr.nodes()
+            ])
 
     # {{{ mayavi
 
@@ -456,13 +366,11 @@ class Visualizer(object):
 
         do_show = kwargs.pop("do_show", True)
 
-        nodes = self.vtk.vis_nodes_numpy()
+        nodes = self._vis_nodes_numpy()
         field = resample_to_numpy(self.connection, field)
 
         assert nodes.shape[0] == self.vis_discr.ambient_dim
-        #mlab.points3d(nodes[0], nodes[1], 0*nodes[0])
-
-        vis_connectivity, = self._vis_connectivity()
+        vis_connectivity = self._vtk_connectivity.groups[0].vis_connectivity
 
         if self.vis_discr.dim == 1:
             nodes = list(nodes)
@@ -502,14 +410,109 @@ class Visualizer(object):
 
     # {{{ vtk
 
+    @property
+    @memoize_method
+    def _vtk_connectivity(self):
+        return VTKConnectivity(self.connection)
+
+    @property
+    @memoize_method
+    def _vtk_lagrange_connectivity(self):
+        return VTKLagrangeConnectivity(self.connection)
+
     def write_vtk_file(self, file_name, names_and_fields,
                        compressor=None,
                        real_only=False,
-                       overwrite=False):
-        self.vtk.write_vtk_file(file_name, names_and_fields,
-                compressor=compressor,
-                real_only=real_only,
-                overwrite=overwrite)
+                       overwrite=False,
+                       use_lagrange_elements=False):
+        from pyvisfile.vtk import (
+                UnstructuredGrid, DataArray,
+                AppendedDataXMLGenerator,
+                VF_LIST_OF_COMPONENTS)
+
+        if use_lagrange_elements:
+            connectivity = self._vtk_lagrange_connectivity
+        else:
+            connectivity = self._vtk_connectivity
+
+        nodes = self._vis_nodes_numpy()
+        names_and_fields = [
+                (name, resample_to_numpy(self.connection, fld))
+                for name, fld in names_and_fields]
+
+        # {{{ create cell_types
+
+        nsubelements = sum(vgrp.nsubelements for vgrp in connectivity.groups)
+        cell_types = np.empty(nsubelements, dtype=np.uint8)
+        cell_types.fill(255)
+
+        for vgrp in connectivity.groups:
+            isubelements = np.s_[
+                    vgrp.subelement_nr_base:
+                    vgrp.subelement_nr_base + vgrp.nsubelements]
+            cell_types[isubelements] = vgrp.vtk_cell_type
+
+        assert (cell_types < 255).all()
+
+        # }}}
+
+        # {{{ shrink elements
+
+        if abs(self.element_shrink_factor - 1.0) > 1.0e-14:
+            node_nr_base = 0
+            for vgrp in self.vis_discr.groups:
+                nodes_view = (
+                        nodes[:, node_nr_base:node_nr_base + vgrp.ndofs]
+                        .reshape(nodes.shape[0], vgrp.nelements, vgrp.nunit_dofs))
+
+                el_centers = np.mean(nodes_view, axis=-1)
+                nodes_view[:] = (
+                        (self.element_shrink_factor * nodes_view)
+                        + (1-self.element_shrink_factor)
+                        * el_centers[:, :, np.newaxis])
+
+                node_nr_base += vgrp.ndofs
+
+        # }}}
+
+        # {{{ create grid
+
+        nodes = nodes.reshape(self.vis_discr.ambient_dim, -1)
+        points = DataArray("points", nodes,
+                vector_format=VF_LIST_OF_COMPONENTS)
+
+        grid = UnstructuredGrid(
+                (nodes.shape[1], points),
+                cells=connectivity.cells,
+                cell_types=cell_types)
+
+        for name, field in separate_by_real_and_imag(names_and_fields, real_only):
+            grid.add_pointdata(
+                    DataArray(name, field, vector_format=VF_LIST_OF_COMPONENTS)
+                    )
+
+        # }}}
+
+        # {{{ write
+
+        import os
+        from meshmode import FileExistsError
+        if os.path.exists(file_name):
+            if overwrite:
+                os.remove(file_name)
+            else:
+                raise FileExistsError("output file '%s' already exists" % file_name)
+
+        with open(file_name, "w") as outf:
+            generator = AppendedDataXMLGenerator(
+                    compressor=compressor,
+                    vtk_file_version=connectivity.version)
+
+            generator(grid).write(outf)
+
+        # }}}
+
+    # }}}
 
     # {{{ matplotlib 3D
 
@@ -524,12 +527,12 @@ class Visualizer(object):
         vmax = kwargs.pop("vmax", None)
         norm = kwargs.pop("norm", None)
 
-        nodes = self.vtk.vis_nodes_numpy()
-        field = resample_to_numpy(field)
+        nodes = self._vis_nodes_numpy()
+        field = resample_to_numpy(self.connection, field)
 
         assert nodes.shape[0] == self.vis_discr.ambient_dim
 
-        vis_connectivity, = self._vis_connectivity()
+        vis_connectivity, = self._vtk_connectivity.groups
 
         fig = plt.gcf()
         ax = fig.gca(projection="3d")
@@ -580,18 +583,12 @@ class Visualizer(object):
     # }}}
 
 
-def make_visualizer(actx, discr, vis_order,
-        element_shrink_factor=None, use_high_order_vtk=False):
+def make_visualizer(actx, discr, vis_order, element_shrink_factor=None):
     from meshmode.discretization import Discretization
 
-    if use_high_order_vtk:
-        from meshmode.discretization.poly_element import (
-                PolynomialEquidistantSimplexElementGroup as SimplexElementGroup,
-                EquidistantTensorProductElementGroup as TensorElementGroup)
-    else:
-        from meshmode.discretization.poly_element import (
-                PolynomialWarpAndBlendElementGroup as SimplexElementGroup,
-                LegendreGaussLobattoTensorProductElementGroup as TensorElementGroup)
+    from meshmode.discretization.poly_element import (
+            PolynomialEquidistantSimplexElementGroup as SimplexElementGroup,
+            EquidistantTensorProductElementGroup as TensorElementGroup)
 
     from meshmode.discretization.poly_element import OrderAndTypeBasedGroupFactory
     vis_discr = Discretization(
@@ -607,8 +604,7 @@ def make_visualizer(actx, discr, vis_order,
 
     return Visualizer(
             make_same_mesh_connection(actx, vis_discr, discr),
-            element_shrink_factor=element_shrink_factor,
-            use_high_order_vtk=use_high_order_vtk)
+            element_shrink_factor=element_shrink_factor)
 
 # }}}
 
diff --git a/test/test_meshmode.py b/test/test_meshmode.py
index bfb6b53961e09c6aca98659b64f8a89dd0741baa..cabbd593904bc82f5abd44cc4079279ccdf99c4c 100644
--- a/test/test_meshmode.py
+++ b/test/test_meshmode.py
@@ -84,6 +84,61 @@ def test_circle_mesh(visualize=False):
 # }}}
 
 
+# {{{ test visualizer
+
+@pytest.mark.parametrize("dim", [1, 2, 3])
+def test_visualizers(ctx_factory, dim):
+    logging.basicConfig(level=logging.INFO)
+
+    cl_ctx = ctx_factory()
+    queue = cl.CommandQueue(cl_ctx)
+    actx = PyOpenCLArrayContext(queue)
+
+    nelements = 64
+    target_order = 4
+
+    if dim == 1:
+        mesh = mgen.make_curve_mesh(
+                mgen.NArmedStarfish(5, 0.25),
+                np.linspace(0.0, 1.0, nelements + 1),
+                target_order)
+    elif dim == 2:
+        mesh = mgen.generate_torus(5.0, 1.0, order=target_order)
+    elif dim == 3:
+        mesh = mgen.generate_warped_rect_mesh(dim, target_order, 5)
+    else:
+        raise ValueError("unknown dimensionality")
+
+    from meshmode.discretization import Discretization
+    discr = Discretization(actx, mesh,
+            InterpolatoryQuadratureSimplexGroupFactory(target_order))
+
+    from meshmode.discretization.visualization import make_visualizer
+    vis = make_visualizer(actx, discr, target_order)
+
+    vis.write_vtk_file(f"visualizer_vtk_lagrange_{dim}.vtu", [],
+            use_lagrange_elements=True, overwrite=True)
+    vis.write_vtk_file(f"visualizer_vtk_linear_{dim}.vtu", [],
+            use_lagrange_elements=False, overwrite=True)
+
+    if mesh.dim <= 2:
+        field = thaw(actx, discr.nodes()[0])
+
+    if mesh.dim == 2:
+        try:
+            vis.show_scalar_in_matplotlib_3d(field, do_show=False)
+        except ImportError:
+            logger.info("matplotlib not available")
+
+    if mesh.dim <= 2:
+        try:
+            vis.show_scalar_in_mayavi(field, do_show=False)
+        except ImportError:
+            logger.info("mayavi not avaiable")
+
+# }}}
+
+
 # {{{ test boundary tags
 
 def test_boundary_tags():