From 63ef939151064208b3db69a0c33ab8c8761741dc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 7 Jun 2022 19:49:52 -0500 Subject: [PATCH] Einsum array argument naming: Use NameHint --- arraycontext/impl/pytato/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 3ec7fe3..3a44cbb 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -44,6 +44,8 @@ THE SOFTWARE. from arraycontext.context import ArrayContext, _ScalarLike from arraycontext.container.traversal import rec_map_array_container +from arraycontext.metadata import NameHint + import numpy as np from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type from pytools.tag import ToTagSetConvertible @@ -351,15 +353,14 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): ary = arg if name is not None: - from pytato.tags import PrefixNamed - # Tagging Placeholders with naming-related tags is pointless: # They already have names. It's also counterproductive, as # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. - if not isinstance(ary, pt.Placeholder): - ary = ary.tagged(PrefixNamed(name)) + if (not isinstance(ary, pt.Placeholder) + and not ary.tags_of_type(NameHint)): + ary = ary.tagged(NameHint(name)) return ary @@ -492,15 +493,14 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): ary = arg if name is not None: - from pytato.tags import PrefixNamed - # Tagging Placeholders with naming-related tags is pointless: # They already have names. It's also counterproductive, as # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. - if not isinstance(ary, pt.Placeholder): - ary = ary.tagged(PrefixNamed(name)) + if (not isinstance(ary, pt.Placeholder) + and not ary.tags_of_type(NameHint)): + ary = ary.tagged(NameHint(name)) return ary -- GitLab