diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 1e00c8d697c01e6a252e49e64ed01dadf1e36f1b..184bdd90f33a3318476273362ae147025c5e6059 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -25,7 +25,7 @@ THE SOFTWARE. from typing import Any, Dict, Set, Tuple, Mapping from pytato.array import SizeParam, Placeholder -from pytato.array import Array, DataWrapper +from pytato.array import Array, DataWrapper, DictOfNamedArrays from pytato.transform import CopyMapper from pytools import UniqueNameGenerator @@ -66,8 +66,8 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): " DatawrapperToBoundPlaceholderMapper.") -def _normalize_pt_expr(expr: Array) -> Tuple[Array, - Mapping[str, Any]]: +def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays, + Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of @@ -78,5 +78,6 @@ def _normalize_pt_expr(expr: Array) -> Tuple[Array, equivalent graphs. """ normalize_mapper = _DatawrapperToBoundPlaceholderMapper() - normalized_expr = normalize_mapper(expr) + # type-ignore reason: Mapper.__call__ takes Array, passed DictOfNamedArrays + normalized_expr = normalize_mapper(expr) # type: ignore return normalized_expr, normalize_mapper.bound_arguments