From e8a57ccb6cb8b2fea165081c227d8a5537193d20 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Thu, 20 Jul 2023 07:16:28 -0700
Subject: [PATCH] eliminate recursive forwarding in mappers

---
 pytato/target/loopy/codegen.py            |  6 ------
 pytato/transform/__init__.py              | 24 +----------------------
 pytato/transform/lower_to_index_lambda.py |  5 -----
 3 files changed, 1 insertion(+), 34 deletions(-)

diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py
index 7d8760e..8879f56 100644
--- a/pytato/target/loopy/codegen.py
+++ b/pytato/target/loopy/codegen.py
@@ -624,12 +624,6 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper):
     def __init__(self, codegen_mapper: CodeGenMapper):
         self.codegen_mapper = codegen_mapper
 
-    def __call__(self, expr: ScalarExpression,
-                 prstnt_ctx: PersistentExpressionContext,
-                 local_ctx: Optional[LocalExpressionContext],
-                 ) -> ScalarExpression:
-        return self.rec(expr, prstnt_ctx, local_ctx)
-
     def map_subscript(self, expr: prim.Subscript,
                       prstnt_ctx: PersistentExpressionContext,
                       local_ctx: LocalExpressionContext,
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index 953cbfa..12b9fd3 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -185,9 +185,7 @@ class Mapper:
         assert method is not None
         return method(expr, *args, **kwargs)
 
-    def __call__(self, expr: MappedT, *args: Any, **kwargs: Any) -> Any:
-        """Handle the mapping of *expr*."""
-        return self.rec(expr, *args, **kwargs)
+    __call__ = rec
 
 # }}}
 
@@ -217,11 +215,6 @@ class CachedMapper(Mapper, Generic[CachedMapperT]):
             # type-ignore-reason: Mapper.rec has imprecise func. signature
             return result  # type: ignore[no-any-return]
 
-    # type-ignore-reason: incompatible with super class
-    def __call__(self, expr: ArrayOrNames  # type: ignore[override]
-                 ) -> CachedMapperT:
-        return self.rec(expr)
-
 # }}}
 
 
@@ -239,17 +232,6 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
        This does not copy the data of a :class:`pytato.array.DataWrapper`.
     """
 
-    # type-ignore-reason: specialized variant of super-class' rec method
-    def rec(self,  # type: ignore[override]
-            expr: CopyMapperResultT) -> CopyMapperResultT:
-        # type-ignore-reason: CachedMapper.rec's return type is imprecise
-        return super().rec(expr)  # type: ignore[return-value]
-
-    # type-ignore-reason: specialized variant of super-class' rec method
-    def __call__(self,  # type: ignore[override]
-                 expr: CopyMapperResultT) -> CopyMapperResultT:
-        return self.rec(expr)
-
     def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
         """
         Called to clone *self* before starting traversal of a
@@ -1233,10 +1215,6 @@ class CachedMapAndCopyMapper(CopyMapper):
         # type-ignore-reason: map_fn has imprecise types
         return result  # type: ignore[return-value]
 
-    # type-ignore-reason: Mapper.__call__ returns Any
-    def __call__(self, expr: MappedT) -> MappedT:  # type: ignore[override]
-        return self.rec(expr)
-
 # }}}
 
 
diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py
index a2bb443..79e5497 100644
--- a/pytato/transform/lower_to_index_lambda.py
+++ b/pytato/transform/lower_to_index_lambda.py
@@ -83,11 +83,6 @@ class ToIndexLambdaMixin:
                      else s
                      for s in shape)
 
-    def rec(self, expr: ToIndexLambdaT, *args: Any, **kwargs: Any) -> ToIndexLambdaT:
-        # type-ignore-reason: mypy is right as we are attempting to make
-        # guarantees about other super-classes.
-        return super().rec(expr, *args, **kwargs)  # type: ignore[no-any-return,misc]
-
     def map_index_lambda(self, expr: IndexLambda) -> IndexLambda:
         return IndexLambda(expr=expr.expr,
                            shape=self._rec_shape(expr.shape),
-- 
GitLab