From c91747d508e482b23f2bf6fe3cb41e4938171601 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Mon, 24 Jul 2023 16:16:38 -0700
Subject: [PATCH] restore forwarding for type checking only

---
 pytato/target/loopy/codegen.py            |  9 ++++++++-
 pytato/transform/__init__.py              | 19 ++++++++++++++++++-
 pytato/transform/lower_to_index_lambda.py | 11 ++++++++++-
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py
index 8879f56..e76498b 100644
--- a/pytato/target/loopy/codegen.py
+++ b/pytato/target/loopy/codegen.py
@@ -34,7 +34,7 @@ import pymbolic.primitives as prim
 from pymbolic import var
 
 from typing import (Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set,
-                    Any, List, Type)
+                    Any, List, Type, TYPE_CHECKING)
 
 
 from pytato.array import (Array, DictOfNamedArrays, ShapeType, IndexLambda,
@@ -624,6 +624,13 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper):
     def __init__(self, codegen_mapper: CodeGenMapper):
         self.codegen_mapper = codegen_mapper
 
+    if TYPE_CHECKING:
+        def __call__(self, expr: ScalarExpression,
+                     prstnt_ctx: PersistentExpressionContext,
+                     local_ctx: Optional[LocalExpressionContext],
+                     ) -> ScalarExpression:
+            return self.rec(expr, prstnt_ctx, local_ctx)
+
     def map_subscript(self, expr: prim.Subscript,
                       prstnt_ctx: PersistentExpressionContext,
                       local_ctx: LocalExpressionContext,
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index 12b9fd3..e192123 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -32,7 +32,7 @@ import logging
 import numpy as np
 from immutables import Map
 from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic,
-                    List, Mapping, Iterable, Tuple, Optional,
+                    List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING,
                     Hashable)
 
 from pytato.array import (
@@ -215,6 +215,12 @@ class CachedMapper(Mapper, Generic[CachedMapperT]):
             # type-ignore-reason: Mapper.rec has imprecise func. signature
             return result  # type: ignore[no-any-return]
 
+    if TYPE_CHECKING:
+        # type-ignore-reason: incompatible with super class
+        def __call__(self, expr: ArrayOrNames  # type: ignore[override]
+                     ) -> CachedMapperT:
+            return self.rec(expr)
+
 # }}}
 
 
@@ -231,6 +237,14 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
 
        This does not copy the data of a :class:`pytato.array.DataWrapper`.
     """
+    if TYPE_CHECKING:
+        # type-ignore-reason: specialized variant of super-class' rec method
+        def rec(self,  # type: ignore[override]
+                expr: CopyMapperResultT) -> CopyMapperResultT:
+            # type-ignore-reason: CachedMapper.rec's return type is imprecise
+            return super().rec(expr)  # type: ignore[return-value]
+
+        __call__ = rec
 
     def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
         """
@@ -1215,6 +1229,9 @@ class CachedMapAndCopyMapper(CopyMapper):
         # type-ignore-reason: map_fn has imprecise types
         return result  # type: ignore[return-value]
 
+    if TYPE_CHECKING:
+        __call__ = rec
+
 # }}}
 
 
diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py
index 79e5497..698566c 100644
--- a/pytato/transform/lower_to_index_lambda.py
+++ b/pytato/transform/lower_to_index_lambda.py
@@ -28,7 +28,7 @@ THE SOFTWARE.
 
 import pymbolic.primitives as prim
 
-from typing import List, Any, Dict, Tuple, TypeVar
+from typing import List, Any, Dict, Tuple, TypeVar, TYPE_CHECKING
 from immutables import Map
 from pytools import UniqueNameGenerator
 from pytato.array import (Array, IndexLambda, Stack, Concatenate,
@@ -83,6 +83,15 @@ class ToIndexLambdaMixin:
                      else s
                      for s in shape)
 
+    if TYPE_CHECKING:
+        def rec(
+                self, expr: ToIndexLambdaT, *args: Any,
+                **kwargs: Any) -> ToIndexLambdaT:
+            # type-ignore-reason: mypy is right as we are attempting to make
+            # guarantees about other super-classes.
+            return super().rec(  # type: ignore[no-any-return,misc]
+                expr, *args, **kwargs)
+
     def map_index_lambda(self, expr: IndexLambda) -> IndexLambda:
         return IndexLambda(expr=expr.expr,
                            shape=self._rec_shape(expr.shape),
-- 
GitLab