From c91747d508e482b23f2bf6fe3cb41e4938171601 Mon Sep 17 00:00:00 2001 From: Matthew Smith <mjsmith6@illinois.edu> Date: Mon, 24 Jul 2023 16:16:38 -0700 Subject: [PATCH] restore forwarding for type checking only --- pytato/target/loopy/codegen.py | 9 ++++++++- pytato/transform/__init__.py | 19 ++++++++++++++++++- pytato/transform/lower_to_index_lambda.py | 11 ++++++++++- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 8879f56..e76498b 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -34,7 +34,7 @@ import pymbolic.primitives as prim from pymbolic import var from typing import (Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set, - Any, List, Type) + Any, List, Type, TYPE_CHECKING) from pytato.array import (Array, DictOfNamedArrays, ShapeType, IndexLambda, @@ -624,6 +624,13 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): def __init__(self, codegen_mapper: CodeGenMapper): self.codegen_mapper = codegen_mapper + if TYPE_CHECKING: + def __call__(self, expr: ScalarExpression, + prstnt_ctx: PersistentExpressionContext, + local_ctx: Optional[LocalExpressionContext], + ) -> ScalarExpression: + return self.rec(expr, prstnt_ctx, local_ctx) + def map_subscript(self, expr: prim.Subscript, prstnt_ctx: PersistentExpressionContext, local_ctx: LocalExpressionContext, diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 12b9fd3..e192123 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -32,7 +32,7 @@ import logging import numpy as np from immutables import Map from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, - List, Mapping, Iterable, Tuple, Optional, + List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING, Hashable) from pytato.array import ( @@ -215,6 +215,12 @@ class CachedMapper(Mapper, Generic[CachedMapperT]): # type-ignore-reason: Mapper.rec has imprecise func. signature return result # type: ignore[no-any-return] + if TYPE_CHECKING: + # type-ignore-reason: incompatible with super class + def __call__(self, expr: ArrayOrNames # type: ignore[override] + ) -> CachedMapperT: + return self.rec(expr) + # }}} @@ -231,6 +237,14 @@ class CopyMapper(CachedMapper[ArrayOrNames]): This does not copy the data of a :class:`pytato.array.DataWrapper`. """ + if TYPE_CHECKING: + # type-ignore-reason: specialized variant of super-class' rec method + def rec(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: + # type-ignore-reason: CachedMapper.rec's return type is imprecise + return super().rec(expr) # type: ignore[return-value] + + __call__ = rec def clone_for_callee(self: _SelfMapper) -> _SelfMapper: """ @@ -1215,6 +1229,9 @@ class CachedMapAndCopyMapper(CopyMapper): # type-ignore-reason: map_fn has imprecise types return result # type: ignore[return-value] + if TYPE_CHECKING: + __call__ = rec + # }}} diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 79e5497..698566c 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -28,7 +28,7 @@ THE SOFTWARE. import pymbolic.primitives as prim -from typing import List, Any, Dict, Tuple, TypeVar +from typing import List, Any, Dict, Tuple, TypeVar, TYPE_CHECKING from immutables import Map from pytools import UniqueNameGenerator from pytato.array import (Array, IndexLambda, Stack, Concatenate, @@ -83,6 +83,15 @@ class ToIndexLambdaMixin: else s for s in shape) + if TYPE_CHECKING: + def rec( + self, expr: ToIndexLambdaT, *args: Any, + **kwargs: Any) -> ToIndexLambdaT: + # type-ignore-reason: mypy is right as we are attempting to make + # guarantees about other super-classes. + return super().rec( # type: ignore[no-any-return,misc] + expr, *args, **kwargs) + def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: return IndexLambda(expr=expr.expr, shape=self._rec_shape(expr.shape), -- GitLab