From 5964b91fe99fb641670821edd5004c1a7cf218dd Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 1 Nov 2021 13:45:35 -0500 Subject: [PATCH] recursively add tags to all arrays in the container --- arraycontext/impl/pytato/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 744fcbc..42619e8 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -42,6 +42,7 @@ THE SOFTWARE. """ from arraycontext.context import ArrayContext, _ScalarLike +from arraycontext.container.traversal import rec_map_array_container import numpy as np from typing import Any, Callable, Union, Sequence, TYPE_CHECKING from pytools.tag import Tag @@ -207,7 +208,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return dag def tag(self, tags: Union[Sequence[Tag], Tag], array): - return array.tagged(tags) + return rec_map_array_container(lambda x: x.tagged(tags), + array) def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): # TODO -- GitLab