From 35dc261453985a96df1cb05aa4d4c06d0b25eb9e Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sat, 14 Aug 2021 16:19:27 -0500 Subject: [PATCH] _normalize_pt_expr: perform the transformation for a dict-of-named-arrays --- arraycontext/impl/pytato/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 1e00c8d..184bdd9 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -25,7 +25,7 @@ THE SOFTWARE. from typing import Any, Dict, Set, Tuple, Mapping from pytato.array import SizeParam, Placeholder -from pytato.array import Array, DataWrapper +from pytato.array import Array, DataWrapper, DictOfNamedArrays from pytato.transform import CopyMapper from pytools import UniqueNameGenerator @@ -66,8 +66,8 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): " DatawrapperToBoundPlaceholderMapper.") -def _normalize_pt_expr(expr: Array) -> Tuple[Array, - Mapping[str, Any]]: +def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays, + Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of @@ -78,5 +78,6 @@ def _normalize_pt_expr(expr: Array) -> Tuple[Array, equivalent graphs. """ normalize_mapper = _DatawrapperToBoundPlaceholderMapper() - normalized_expr = normalize_mapper(expr) + # type-ignore reason: Mapper.__call__ takes Array, passed DictOfNamedArrays + normalized_expr = normalize_mapper(expr) # type: ignore return normalized_expr, normalize_mapper.bound_arguments -- GitLab