diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 3ec7fe30e53ac87421003d889be9945c83f80ae1..3a44cbbc4443d2f2c66d5de907f1291990f4cb2c 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -44,6 +44,8 @@ THE SOFTWARE. from arraycontext.context import ArrayContext, _ScalarLike from arraycontext.container.traversal import rec_map_array_container +from arraycontext.metadata import NameHint + import numpy as np from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type from pytools.tag import ToTagSetConvertible @@ -351,15 +353,14 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): ary = arg if name is not None: - from pytato.tags import PrefixNamed - # Tagging Placeholders with naming-related tags is pointless: # They already have names. It's also counterproductive, as # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. - if not isinstance(ary, pt.Placeholder): - ary = ary.tagged(PrefixNamed(name)) + if (not isinstance(ary, pt.Placeholder) + and not ary.tags_of_type(NameHint)): + ary = ary.tagged(NameHint(name)) return ary @@ -492,15 +493,14 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): ary = arg if name is not None: - from pytato.tags import PrefixNamed - # Tagging Placeholders with naming-related tags is pointless: # They already have names. It's also counterproductive, as # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. - if not isinstance(ary, pt.Placeholder): - ary = ary.tagged(PrefixNamed(name)) + if (not isinstance(ary, pt.Placeholder) + and not ary.tags_of_type(NameHint)): + ary = ary.tagged(NameHint(name)) return ary