diff --git a/grudge/array_context.py b/grudge/array_context.py
index 8ee89a35e76aad7843dc62c8bd534898648cb616..171016bfe3ccc2bd236ae7933518a2a8da930e5e 100644
--- a/grudge/array_context.py
+++ b/grudge/array_context.py
@@ -32,12 +32,14 @@ THE SOFTWARE.
 # {{{ imports
 
 from typing import (
-        TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type)
+        TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type,
+        FrozenSet)
 from dataclasses import dataclass
-
+from pytools.tag import Tag
 from meshmode.array_context import (
         PyOpenCLArrayContext as _PyOpenCLArrayContextBase,
         PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase)
+from pyrsistent import pmap
 
 import logging
 logger = logging.getLogger(__name__)
@@ -160,6 +162,8 @@ class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller):
         # }}}
 
         part_id_to_prg = {}
+        name_in_program_to_tags = pmap()
+        name_in_program_to_axes = pmap()
 
         from pytato import DictOfNamedArrays
         for part in distributed_partition.parts.values():
@@ -167,7 +171,20 @@ class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller):
                         {var_name: distributed_partition.var_name_to_result[var_name]
                             for var_name in part.output_names
                          })
-            part_id_to_prg[part.pid], _, _ = self._dag_to_transformed_loopy_prg(d)
+            (
+                part_id_to_prg[part.pid],
+                part_prg_name_to_tags,
+                part_prg_name_to_axes
+            ) = self._dag_to_transformed_loopy_prg(d)
+
+            assert not (set(name_in_program_to_tags.keys())
+                        & set(part_prg_name_to_tags.keys()))
+            assert not (set(name_in_program_to_axes.keys())
+                        & set(part_prg_name_to_axes.keys()))
+            name_in_program_to_tags = name_in_program_to_tags.update(
+                part_prg_name_to_tags)
+            name_in_program_to_axes = name_in_program_to_axes.update(
+                part_prg_name_to_axes)
 
         return _DistributedCompiledFunction(
                 actx=self.actx,
@@ -175,6 +192,8 @@ class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller):
                 part_id_to_prg=part_id_to_prg,
                 input_id_to_name_in_program=input_id_to_name_in_program,
                 output_id_to_name_in_program=output_id_to_name_in_program,
+                name_in_program_to_tags=name_in_program_to_tags,
+                name_in_program_to_axes=name_in_program_to_axes,
                 output_template=output_template)
 
 
@@ -213,6 +232,8 @@ class _DistributedCompiledFunction:
     part_id_to_prg: "Mapping[PartId, pt.target.BoundProgram]"
     input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
     output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
+    name_in_program_to_tags: Mapping[str, FrozenSet[Tag]]
+    name_in_program_to_axes: Mapping[str, Tuple["pt.Axis", ...]]
     output_template: ArrayContainer
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
@@ -223,6 +244,8 @@ class _DistributedCompiledFunction:
         """
 
         from arraycontext.impl.pytato.compile import _args_to_cl_buffers
+        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
+        from arraycontext.impl.pytato.utils import get_cl_axes_from_pt_axes
         input_args_for_prg = _args_to_cl_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
 
@@ -234,7 +257,12 @@ class _DistributedCompiledFunction:
                 input_args=input_args_for_prg)
 
         def to_output_template(keys, _):
-            return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]])
+            ary_name_in_prg = self.output_id_to_name_in_program[keys]
+            return self.actx.thaw(to_tagged_cl_array(
+                out_dict[ary_name_in_prg],
+                axes=get_cl_axes_from_pt_axes(
+                    self.name_in_program_to_axes[ary_name_in_prg]),
+                tags=self.name_in_program_to_tags[ary_name_in_prg]))
 
         from arraycontext.container.traversal import rec_keyed_map_array_container
         return rec_keyed_map_array_container(to_output_template,