diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 42619e8eb738da4a104c22bac0f1c4a83d573559..f7592520ccdf98b8fa6cc8e05c743c9969dea854 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -212,11 +212,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): array) def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): - # TODO - from warnings import warn - warn("tagging PytatoPyOpenCLArrayContext's array axes: not yet implemented", - stacklevel=2) - return array + return rec_map_array_container(lambda x: x.with_tagged_axis(iaxis, + tags), + array) def einsum(self, spec, *args, arg_names=None, tagged=()): import pyopencl.array as cla