diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 3a44cbbc4443d2f2c66d5de907f1291990f4cb2c..211f0c4ba932901d924a7ca599b010c4c1f020df 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -47,14 +47,44 @@ from arraycontext.container.traversal import rec_map_array_container from arraycontext.metadata import NameHint import numpy as np -from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type -from pytools.tag import ToTagSetConvertible +from typing import Any, Callable, Union, TYPE_CHECKING, Tuple, Type, FrozenSet +from pytools.tag import ToTagSetConvertible, normalize_tags, Tag import abc if TYPE_CHECKING: import pytato +# {{{ tag conversion + +def _preprocess_array_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]: + tags = normalize_tags(tags) + + name_hints = [tag for tag in tags if isinstance(tag, NameHint)] + if name_hints: + name_hint, = name_hints + + from pytato.tags import PrefixNamed + prefix_nameds = [tag for tag in tags if isinstance(tag, PrefixNamed)] + + if prefix_nameds: + prefix_named, = prefix_nameds + from warnings import warn + warn("When converting a " + f"arraycontext.metadata.NameHint('{name_hint.name}') " + "to pytato.tags.PrefixNamed, " + f"PrefixNamed('{prefix_named.prefix}') " + "was already present.") + + tags = ( + (tags | frozenset({PrefixNamed(name_hint.name)})) + - {name_hint}) + + return tags + +# }}} + + # {{{ _BasePytatoArrayContext class _BasePytatoArrayContext(ArrayContext, abc.ABC): @@ -320,8 +350,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): return dag def tag(self, tags: ToTagSetConvertible, array): - return rec_map_array_container(lambda x: x.tagged(tags), - array) + return rec_map_array_container( + lambda x: x.tagged(_preprocess_array_tags(tags)), + array) def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): return rec_map_array_container( @@ -367,7 +398,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): return pt.einsum(spec, *[ preprocess_arg(name, arg) for name, arg in zip(arg_names, args) - ]) + ]).tagged(_preprocess_array_tags(tagged)) # }}} @@ -461,7 +492,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): return ary else: assert isinstance(ary, pt.Array) - return ary.tagged(tags) + return ary.tagged(_preprocess_array_tags(tags)) return rec_map_array_container(_rec_tag, array) @@ -507,7 +538,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): return pt.einsum(spec, *[ preprocess_arg(name, arg) for name, arg in zip(arg_names, args) - ]) + ]).tagged(_preprocess_array_tags(tagged)) # }}}