diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index f7592520ccdf98b8fa6cc8e05c743c9969dea854..730133e25104f8824322e1bd8b17cbbe2bd01fab 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -79,6 +79,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         self.allocator = allocator
         self.array_types = (pt.Array, )
         self._freeze_prg_cache = {}
+        self._dag_transform_cache = {}
 
         # unused, but necessary to keep the context alive
         self.context = self.queue.context
@@ -113,24 +114,56 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         return cl_array.get(queue=self.queue)
 
     def call_loopy(self, program, **kwargs):
-        import pyopencl.array as cla
+        from pytato.scalar_expr import SCALAR_CLASSES
         from pytato.loopy import call_loopy
+        from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
 
         entrypoint = program.default_entrypoint.name
 
-        # thaw frozen arrays
-        kwargs = {kw: (self.thaw(arg) if isinstance(arg, cla.Array) else arg)
-                  for kw, arg in kwargs.items()}
+        # {{{ preprocess args
+
+        processed_kwargs = {}
+
+        for kw, arg in sorted(kwargs.items()):
+            if isinstance(arg, self.array_types + SCALAR_CLASSES):
+                pass
+            elif isinstance(arg, TaggableCLArray):
+                arg = self.thaw(arg)
+            else:
+                raise ValueError(f"call_loopy argument '{kw}' expected to be an"
+                                 " instance of 'pytato.Array', 'Number' or"
+                                 f"'TaggableCLArray', got '{type(arg)}'")
+
+            processed_kwargs[kw] = arg
+
+        # }}}
 
-        return call_loopy(program, kwargs, entrypoint)
+        return call_loopy(program, processed_kwargs, entrypoint)
 
     def freeze(self, array):
         import pytato as pt
         import pyopencl.array as cla
         import loopy as lp
+        from arraycontext.impl.pytato.utils import (_normalize_pt_expr,
+                                                    get_cl_axes_from_pt_axes)
+        from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array,
+                                                                  TaggableCLArray)
 
-        if isinstance(array, cla.Array):
+        if isinstance(array, TaggableCLArray):
             return array.with_queue(None)
+        if isinstance(array, cla.Array):
+            from warnings import warn
+            warn("Freezing pyopencl.array.Array will be deprecated in 2023."
+                 " Use `to_tagged_cl_array` to convert the array to"
+                 " TaggableCLArray", DeprecationWarning, stacklevel=2)
+            return to_tagged_cl_array(array.with_queue(None),
+                                      axes=None,
+                                      tags=frozenset())
+        if isinstance(array, pt.DataWrapper):
+            # trivial freeze.
+            return to_tagged_cl_array(array.data.with_queue(None),
+                                      axes=get_cl_axes_from_pt_axes(array.axes),
+                                      tags=array.tags)
         if not isinstance(array, pt.Array):
             raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
                             f"non-pytato array of type '{type(array)}'")
@@ -138,14 +171,16 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         # {{{ early exit for 0-sized arrays
 
         if array.size == 0:
-            return cla.empty(self.queue.context,
-                             shape=array.shape,
-                             dtype=array.dtype,
-                             allocator=self.allocator)
+            return to_tagged_cl_array(
+                cla.empty(self.queue.context,
+                          shape=array.shape,
+                          dtype=array.dtype,
+                          allocator=self.allocator),
+                get_cl_axes_from_pt_axes(array.axes),
+                array.tags)
 
         # }}}
 
-        from arraycontext.impl.pytato.utils import _normalize_pt_expr
         pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
                 {"_actx_out": array})
 
@@ -155,7 +190,13 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         try:
             pt_prg = self._freeze_prg_cache[normalized_expr]
         except KeyError:
-            pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr),
+            if normalized_expr in self._dag_transform_cache:
+                transformed_dag = self._dag_transform_cache[normalized_expr]
+            else:
+                transformed_dag = self.transform_dag(normalized_expr)
+                self._dag_transform_cache[normalized_expr] = transformed_dag
+
+            pt_prg = pt.generate_loopy(transformed_dag,
                                        options=lp.Options(return_dict=True,
                                                           no_numpy=True),
                                        cl_device=self.queue.device)
@@ -166,17 +207,31 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         evt, out_dict = pt_prg(self.queue, **bound_arguments)
         evt.wait()
 
-        return out_dict["_actx_out"].with_queue(None)
+        return to_tagged_cl_array(
+            out_dict["_actx_out"].with_queue(None),
+            get_cl_axes_from_pt_axes(
+                self._dag_transform_cache[normalized_expr]["_actx_out"].expr.axes),
+            array.tags)
 
     def thaw(self, array):
         import pytato as pt
-        import pyopencl.array as cla
-
-        if not isinstance(array, cla.Array):
-            raise TypeError("PytatoPyOpenCLArrayContext.thaw expects CL arrays, got "
-                    f"{type(array)}")
-
-        return pt.make_data_wrapper(array.with_queue(self.queue))
+        from .utils import get_pt_axes_from_cl_axes
+        from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
+                                                                  to_tagged_cl_array)
+        import pyopencl.array as cl_array
+
+        if isinstance(array, TaggableCLArray):
+            pass
+        elif isinstance(array, cl_array.Array):
+            array = to_tagged_cl_array(array, axes=None, tags=frozenset())
+        else:
+            raise TypeError("PytatoPyOpenCLArrayContext.thaw expects "
+                            "'TaggableCLArray' or 'cl.array.Array' got "
+                            f"{type(array)}.")
+
+        return pt.make_data_wrapper(array.with_queue(self.queue),
+                                    axes=get_pt_axes_from_cl_axes(array.axes),
+                                    tags=array.tags)
 
     # }}}
 
@@ -219,12 +274,23 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
     def einsum(self, spec, *args, arg_names=None, tagged=()):
         import pyopencl.array as cla
         import pytato as pt
+        from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
+                                                                  to_tagged_cl_array)
         if arg_names is None:
             arg_names = (None,) * len(args)
 
         def preprocess_arg(name, arg):
-            if isinstance(arg, cla.Array):
+            if isinstance(arg, TaggableCLArray):
                 ary = self.thaw(arg)
+            elif isinstance(arg, cla.Array):
+                from warnings import warn
+                warn("Passing pyopencl.array.Array to einsum will be "
+                     "deprecated in 2023."
+                     " Use `to_tagged_cl_array` to convert the array to"
+                     " TaggableCLArray.", DeprecationWarning, stacklevel=2)
+                ary = self.thaw(to_tagged_cl_array(arg,
+                                                   axes=None,
+                                                   tags=frozenset()))
             else:
                 assert isinstance(arg, pt.Array)
                 ary = arg
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 71f98a8c5c6ffd5dbb3027082b8cd87d0b6fe146..d83e376a5c10f7f8ea82a7f4547978308af157f6 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -34,7 +34,7 @@ from arraycontext.container.traversal import rec_keyed_map_array_container
 
 import abc
 import numpy as np
-from typing import Any, Callable, Tuple, Dict, Mapping
+from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet
 from dataclasses import dataclass, field
 from pyrsistent import pmap, PMap
 
@@ -169,7 +169,11 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
     elif is_array_container_type(arg.__class__):
         def _rec_to_placeholder(keys, ary):
             name = arg_id_to_name[(kw,) + keys]
-            return pt.make_placeholder(name, ary.shape, ary.dtype)
+            return pt.make_placeholder(name,
+                                       ary.shape,
+                                       ary.dtype,
+                                       axes=ary.axes,
+                                       tags=ary.tags)
 
         return rec_keyed_map_array_container(_rec_to_placeholder, arg)
     else:
@@ -204,6 +208,13 @@ class LazilyCompilingFunctionCaller:
         with ProcessLogger(logger, "transform_dag"):
             pt_dict_of_named_arrays = self.actx.transform_dag(dict_of_named_arrays)
 
+        name_in_program_to_tags = {
+            name: out.tags
+            for name, out in pt_dict_of_named_arrays._data.items()}
+        name_in_program_to_axes = {
+            name: out.axes
+            for name, out in pt_dict_of_named_arrays._data.items()}
+
         with ProcessLogger(logger, "generate_loopy"):
             pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
                                                options=lp.Options(
@@ -225,7 +236,7 @@ class LazilyCompilingFunctionCaller:
                                                         .actx
                                                         .transform_loopy_program))
 
-        return pytato_program
+        return pytato_program, name_in_program_to_tags, name_in_program_to_axes
 
     def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays,
             input_id_to_name_in_program, output_id_to_name_in_program,
@@ -234,18 +245,23 @@ class LazilyCompilingFunctionCaller:
             output_id = "_pt_out"
             dict_of_named_arrays = pt.make_dict_of_named_arrays(
                 {output_id: ary_or_dict_of_named_arrays})
-            pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays)
+            pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
+                self._dag_to_transformed_loopy_prg(dict_of_named_arrays))
             return CompiledFunctionReturningArray(
                 self.actx, pytato_program,
                 input_id_to_name_in_program=input_id_to_name_in_program,
-                output_name_in_program=output_id)
+                output_tags=name_in_program_to_tags[output_id],
+                output_axes=name_in_program_to_axes[output_id],
+                output_name=output_id)
         elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays):
-            pytato_program = self._dag_to_transformed_loopy_prg(
-                ary_or_dict_of_named_arrays)
+            pytato_program, name_in_program_to_tags, name_in_program_to_axes = (
+                self._dag_to_transformed_loopy_prg(ary_or_dict_of_named_arrays))
             return CompiledFunctionReturningArrayContainer(
                     self.actx, pytato_program,
                     input_id_to_name_in_program=input_id_to_name_in_program,
                     output_id_to_name_in_program=output_id_to_name_in_program,
+                    name_in_program_to_tags=name_in_program_to_tags,
+                    name_in_program_to_axes=name_in_program_to_axes,
                     output_template=output_template)
         else:
             raise NotImplementedError(type(ary_or_dict_of_named_arrays))
@@ -312,6 +328,8 @@ class LazilyCompilingFunctionCaller:
 
 
 def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
+    from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
+
     input_kwargs_for_loopy = {}
 
     for arg_id, arg in arg_id_to_arg.items():
@@ -320,7 +338,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
         elif isinstance(arg, pt.array.DataWrapper):
             # got a Datwwrapper => simply gets its data
             arg = arg.data
-        elif isinstance(arg, cla.Array):
+        elif isinstance(arg, TaggableCLArray):
             # got a frozen array  => do nothing
             pass
         elif isinstance(arg, pt.Array):
@@ -383,9 +401,14 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
     pytato_program: pt.target.BoundProgram
     input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
     output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
+    name_in_program_to_tags: Mapping[str, FrozenSet[Tag]]
+    name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]]
     output_template: ArrayContainer
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
+        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
+        from .utils import get_cl_axes_from_pt_axes
+
         input_kwargs_for_loopy = _args_to_cl_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
 
@@ -399,7 +422,12 @@ class CompiledFunctionReturningArrayContainer(CompiledFunction):
         evt.wait()
 
         def to_output_template(keys, _):
-            return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]])
+            name_in_program = self.output_id_to_name_in_program[keys]
+            return self.actx.thaw(to_tagged_cl_array(
+                out_dict[name_in_program],
+                axes=get_cl_axes_from_pt_axes(
+                    self.name_in_program_to_axes[name_in_program]),
+                tags=self.name_in_program_to_tags[name_in_program]))
 
         return rec_keyed_map_array_container(to_output_template,
                                              self.output_template)
@@ -415,9 +443,14 @@ class CompiledFunctionReturningArray(CompiledFunction):
     actx: PytatoPyOpenCLArrayContext
     pytato_program: pt.target.BoundProgram
     input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
+    output_tags: FrozenSet[Tag]
+    output_axes: Tuple[pt.Axis, ...]
     output_name: str
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
+        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
+        from .utils import get_cl_axes_from_pt_axes
+
         input_kwargs_for_loopy = _args_to_cl_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
 
@@ -430,4 +463,7 @@ class CompiledFunctionReturningArray(CompiledFunction):
         # running out of memory. This mitigates that risk a bit, for now.
         evt.wait()
 
-        return self.actx.thaw(out_dict[self.output_name])
+        return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
+                                                 axes=get_cl_axes_from_pt_axes(
+                                                     self.output_axes),
+                                                 tags=self.output_tags))
diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py
index f14d166e48bc5e3f1194a74824e39b8d37da0c4e..2babd559856c2d3518301562fd48035297ae4641 100644
--- a/arraycontext/impl/pytato/utils.py
+++ b/arraycontext/impl/pytato/utils.py
@@ -58,6 +58,7 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
                     shape=tuple(self.rec(s) if isinstance(s, Array) else s
                                 for s in expr.shape),
                     dtype=expr.dtype,
+                    axes=expr.axes,
                     tags=expr.tags)
 
     def map_size_param(self, expr: SizeParam) -> Array: