diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py
index ecde87d14cf19c4cce6761a7623991318879d971..d34344788695e783ff5f5baa6f3dd2388e8f2115 100644
--- a/arraycontext/impl/pyopencl/taggable_cl_array.py
+++ b/arraycontext/impl/pyopencl/taggable_cl_array.py
@@ -6,8 +6,8 @@
 """
 
 import pyopencl.array as cla
-from typing import FrozenSet, Union, Sequence, Optional, Tuple
-from pytools.tag import Taggable, Tag
+from typing import Any, Dict, FrozenSet, Optional, Tuple
+from pytools.tag import Taggable, Tag, TagsType, TagOrIterableType
 from dataclasses import dataclass
 from pytools import memoize
 
@@ -29,6 +29,20 @@ def _construct_untagged_axes(ndim: int) -> Tuple[Axis, ...]:
     return tuple(Axis(frozenset()) for _ in range(ndim))
 
 
+def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]:
+    """Return the array-constructor keyword arguments that describe *ary*."""
+    return {
+        "shape": ary.shape, "dtype": ary.dtype,
+        "allocator": ary.allocator, "strides": ary.strides,
+        "data": ary.base_data, "offset": ary.offset,
+        "events": ary.events,
+        "_context": ary.context,
+        "_queue": ary.queue,
+        "_size": ary.size,
+        "_fast": True,
+    }
+
+
 class TaggableCLArray(cla.Array, Taggable):
     """
     A :class:`pyopencl.array.Array` with additional metadata. This is used by
@@ -60,58 +74,40 @@ class TaggableCLArray(cla.Array, Taggable):
                          _size=_size, _context=_context,
                          _queue=_queue)
 
+        if __debug__:
+            if not isinstance(tags, frozenset):
+                raise TypeError("tags are not a frozenset")
+
+            if axes is not None and len(axes) != self.ndim:
+                raise ValueError("axes length does not match array dimension: "
+                                 f"got {len(axes)} axes for {self.ndim}d array")
+
+        if axes is None:
+            axes = _construct_untagged_axes(self.ndim)
+
         self.tags = tags
-        axes = axes if axes is not None else _construct_untagged_axes(len(self
-                                                                          .shape))
         self.axes = axes
 
-    def copy(self, queue=cla._copy_queue, tags=None, axes=None, _new_class=None):
-        """
-        :arg _new_class: The class of the copy. :func:`to_tagged_cl_array` is
-            sets this to convert instances of :class:`pyopencl.array.Array` to
-            :class:`TaggableCLArray`. If not provided, defaults to
-            ``self.__class__``.
-        """
-        _new_class = self.__class__ if _new_class is None else _new_class
-
-        if queue is not cla._copy_queue:
-            # Copying command queue is an involved operation, use super-class'
-            # implementation.
-            base_instance = super().copy(queue=queue)
-        else:
-            base_instance = self
-
-        if tags is None and axes is None and _new_class is self.__class__:
-            # early exit
-            return base_instance
-
-        tags = getattr(base_instance, "tags", frozenset()) if tags is None else tags
-        axes = getattr(base_instance, "axes", None) if axes is None else axes
-
-        return _new_class(None,
-                          base_instance.shape,
-                          base_instance.dtype,
-                          allocator=base_instance.allocator,
-                          strides=base_instance.strides,
-                          data=base_instance.base_data,
-                          offset=base_instance.offset,
-                          events=base_instance.events, _fast=True,
-                          _context=base_instance.context,
-                          _queue=base_instance.queue,
-                          _size=base_instance.size,
-                          tags=tags,
-                          axes=axes,
-                          )
+    def copy(self, queue=cla._copy_queue):
+        # the base class makes the new array; tags and axes are reattached here
+        fields = _unwrap_cl_array(super().copy(queue=queue))
+        return type(self)(None, tags=self.tags, axes=self.axes, **fields)
+
+    def _with_new_tags(self, tags: TagsType) -> "TaggableCLArray":
+        fields = _unwrap_cl_array(self)  # same data as *self*, only tags differ
+        return type(self)(None, tags=tags, axes=self.axes, **fields)
 
     def with_tagged_axis(self, iaxis: int,
-                         tags: Union[Sequence[Tag], Tag]) -> "TaggableCLArray":
+                         tags: TagOrIterableType) -> "TaggableCLArray":
         """
         Returns a copy of *self* with *iaxis*-th axis tagged with *tags*.
         """
         new_axes = (self.axes[:iaxis]
                     + (self.axes[iaxis].tagged(tags),)
                     + self.axes[iaxis+1:])
-        return self.copy(axes=new_axes)
+        # No data is copied: the result wraps the same base_data as *self*.
+        return type(self)(None, tags=self.tags, axes=new_axes,
+                          **_unwrap_cl_array(self))
 
 
 def to_tagged_cl_array(ary: cla.Array,
@@ -119,11 +115,32 @@ def to_tagged_cl_array(ary: cla.Array,
                        tags: FrozenSet[Tag]) -> TaggableCLArray:
     """
     Returns a :class:`TaggableCLArray` that is constructed from the data in
-    *ary* along with the metadata from *axes* and *tags*.
+    *ary* along with the metadata from *axes* and *tags*. If *ary* is already a
+    :class:`TaggableCLArray`, *tags* and the tags of each entry in *axes* are
+    added to the tags it already carries.
 
     :arg axes: An instance of :class:`Axis` for each dimension of the
         array. If passed *None*, then initialized to a :class:`pytato.Axis`
         with no tags attached for each dimension.
     """
-    return TaggableCLArray.copy(ary, axes=axes, tags=tags,
-                                _new_class=TaggableCLArray)
+    if axes and len(axes) != ary.ndim:
+        raise ValueError("axes length does not match array dimension: "
+                         f"got {len(axes)} axes for {ary.ndim}d array")
+
+    from pytools.tag import normalize_tags
+    tags = normalize_tags(tags)
+
+    if isinstance(ary, TaggableCLArray):
+        if axes:
+            for i, axis in enumerate(axes):
+                ary = ary.with_tagged_axis(i, axis.tags)
+
+        if tags:
+            ary = ary.tagged(tags)
+
+        return ary
+    elif isinstance(ary, cla.Array):
+        return TaggableCLArray(None, tags=tags, axes=axes,
+                               **_unwrap_cl_array(ary))
+    else:
+        raise TypeError(f"unsupported array type: '{type(ary).__name__}'")
diff --git a/setup.py b/setup.py
index 62ff4a7bae995e792efb219c3febb77e9f6f0066..8b0d677b49f2fc633a4ce3b1da8d8f8a146df57e 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,10 @@ def main():
         python_requires="~=3.6",
         install_requires=[
             "numpy",
-            "pytools>=2020.4.1",
+
+            # https://github.com/inducer/arraycontext/pull/147
+            "pytools>=2022.1.1",
+
             "pytest>=2.3",
             "loopy>=2019.1",
             "dataclasses; python_version<'3.7'",
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 92195f5aabb19796944e86e1e37410c6afc5b100..54acefceafe5a00b989e3cdcc0af8e1a6ae9b7ac 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -1415,6 +1415,8 @@ def test_array_container_with_numpy(actx_factory):
 # }}}
 
 
+# {{{ test_actx_compile_on_pure_array_return
+
 def test_actx_compile_on_pure_array_return(actx_factory):
     def _twice(x):
         return 2 * x
@@ -1424,6 +1426,55 @@ def test_actx_compile_on_pure_array_return(actx_factory):
     np.testing.assert_allclose(actx.to_numpy(_twice(ones)),
                                actx.to_numpy(actx.compile(_twice)(ones)))
 
+# }}}
+
+
+# {{{ test_taggable_cl_array_tags
+
+def test_taggable_cl_array_tags(actx_factory):
+    actx = actx_factory()
+    if not isinstance(actx, PyOpenCLArrayContext):
+        pytest.skip(f"not relevant for '{type(actx).__name__}'")
+
+    import pyopencl.array as cl_array
+    ary = cl_array.to_device(actx.queue, np.zeros((32, 7)))
+
+    # {{{ check tags are set when wrapping a plain cl array (data is shared)
+
+    from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
+    tagged_ary = to_tagged_cl_array(ary, axes=None,
+                                    tags=frozenset((FirstAxisIsElementsTag(),)))
+
+    assert tagged_ary.base_data is ary.base_data
+    assert tagged_ary.tags == frozenset((FirstAxisIsElementsTag(),))
+
+    # }}}
+
+    # {{{ check new tags are merged into an already tagged array
+
+    from arraycontext import ElementwiseMapKernelTag
+    tagged_ary = to_tagged_cl_array(tagged_ary, axes=None,
+                                    tags=frozenset((ElementwiseMapKernelTag(),)))
+
+    assert tagged_ary.base_data is ary.base_data
+    assert tagged_ary.tags == frozenset(
+        (FirstAxisIsElementsTag(), ElementwiseMapKernelTag())
+    )
+
+    # }}}
+
+    # {{{ check copy() keeps tags and axes but has a different base_data
+
+    copy_tagged_ary = tagged_ary.copy()
+
+    assert copy_tagged_ary.tags == tagged_ary.tags
+    assert copy_tagged_ary.axes == tagged_ary.axes
+    assert copy_tagged_ary.base_data != tagged_ary.base_data
+
+    # }}}
+
+# }}}
+
 
 if __name__ == "__main__":
     import sys