diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py
index aa7d12d583cba1193d2ab03eeda115f7a273db41..439ca580594bab5befdc7e4e5c7d8c3fad403490 100644
--- a/arraycontext/impl/pyopencl/taggable_cl_array.py
+++ b/arraycontext/impl/pyopencl/taggable_cl_array.py
@@ -19,9 +19,9 @@ class Axis(Taggable):
     """
     tags: FrozenSet[Tag]
 
-    def copy(self, **kwargs):
+    def _with_new_tags(self, tags: FrozenSet[Tag]) -> "Axis":
         from dataclasses import replace
-        return replace(self, **kwargs)
+        return replace(self, tags=tags)
 
 
 @memoize