From 35dc261453985a96df1cb05aa4d4c06d0b25eb9e Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sat, 14 Aug 2021 16:19:27 -0500
Subject: [PATCH] _normalize_pt_expr: perform the transformation for a
 dict-of-named-arrays

---
 arraycontext/impl/pytato/utils.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py
index 1e00c8d..184bdd9 100644
--- a/arraycontext/impl/pytato/utils.py
+++ b/arraycontext/impl/pytato/utils.py
@@ -25,7 +25,7 @@ THE SOFTWARE.
 
 from typing import Any, Dict, Set, Tuple, Mapping
 from pytato.array import SizeParam, Placeholder
-from pytato.array import Array, DataWrapper
+from pytato.array import Array, DataWrapper, DictOfNamedArrays
 from pytato.transform import CopyMapper
 from pytools import UniqueNameGenerator
 
@@ -66,8 +66,8 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
                          " DatawrapperToBoundPlaceholderMapper.")
 
 
-def _normalize_pt_expr(expr: Array) -> Tuple[Array,
-                                            Mapping[str, Any]]:
+def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays,
+                                                         Mapping[str, Any]]:
     """
     Returns ``(normalized_expr, bound_arguments)``.  *normalized_expr* is a
     normalized form of *expr*, with all instances of
@@ -78,5 +78,6 @@ def _normalize_pt_expr(expr: Array) -> Tuple[Array,
     equivalent graphs.
     """
     normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
-    normalized_expr = normalize_mapper(expr)
+    # type-ignore reason: Mapper.__call__ takes Array, passed DictOfNamedArrays
+    normalized_expr = normalize_mapper(expr)  # type: ignore
     return normalized_expr, normalize_mapper.bound_arguments
-- 
GitLab