diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 744fcbc4ec791133d19d641dbb1d43c3447894f2..42619e8eb738da4a104c22bac0f1c4a83d573559 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -42,6 +42,7 @@ THE SOFTWARE. """ from arraycontext.context import ArrayContext, _ScalarLike +from arraycontext.container.traversal import rec_map_array_container import numpy as np from typing import Any, Callable, Union, Sequence, TYPE_CHECKING from pytools.tag import Tag @@ -207,7 +208,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return dag def tag(self, tags: Union[Sequence[Tag], Tag], array): - return array.tagged(tags) + return rec_map_array_container(lambda x: x.tagged(tags), + array) def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): # TODO