diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 8879f566c2cb813afb34b7301f0de9f39e65673e..e76498b21bd72b3aea2a4532c18f53bb3db524e8 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -34,7 +34,7 @@ import pymbolic.primitives as prim from pymbolic import var from typing import (Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set, - Any, List, Type) + Any, List, Type, TYPE_CHECKING) from pytato.array import (Array, DictOfNamedArrays, ShapeType, IndexLambda, @@ -624,6 +624,13 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper): def __init__(self, codegen_mapper: CodeGenMapper): self.codegen_mapper = codegen_mapper + if TYPE_CHECKING: + def __call__(self, expr: ScalarExpression, + prstnt_ctx: PersistentExpressionContext, + local_ctx: Optional[LocalExpressionContext], + ) -> ScalarExpression: + return self.rec(expr, prstnt_ctx, local_ctx) + def map_subscript(self, expr: prim.Subscript, prstnt_ctx: PersistentExpressionContext, local_ctx: LocalExpressionContext, diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 12b9fd371631691c1793de01aa82b842079f66dd..e192123f1bd439d416862ea8187bee8f03548205 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -32,7 +32,7 @@ import logging import numpy as np from immutables import Map from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, - List, Mapping, Iterable, Tuple, Optional, + List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING, Hashable) from pytato.array import ( @@ -215,6 +215,12 @@ class CachedMapper(Mapper, Generic[CachedMapperT]): # type-ignore-reason: Mapper.rec has imprecise func. signature return result # type: ignore[no-any-return] + if TYPE_CHECKING: + # type-ignore-reason: incompatible with super class + def __call__(self, expr: ArrayOrNames # type: ignore[override] + ) -> CachedMapperT: + return self.rec(expr) + # }}} @@ -231,6 +237,14 @@ class CopyMapper(CachedMapper[ArrayOrNames]): This does not copy the data of a :class:`pytato.array.DataWrapper`. """ + if TYPE_CHECKING: + # type-ignore-reason: specialized variant of super-class' rec method + def rec(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: + # type-ignore-reason: CachedMapper.rec's return type is imprecise + return super().rec(expr) # type: ignore[return-value] + + __call__ = rec def clone_for_callee(self: _SelfMapper) -> _SelfMapper: """ @@ -1215,6 +1229,9 @@ class CachedMapAndCopyMapper(CopyMapper): # type-ignore-reason: map_fn has imprecise types return result # type: ignore[return-value] + if TYPE_CHECKING: + __call__ = rec + # }}} diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 79e549749846c2617b485fe9ff9de12bd30baf24..698566ced71b86f0b0580d9dde14a1606d2f4975 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -28,7 +28,7 @@ THE SOFTWARE. import pymbolic.primitives as prim -from typing import List, Any, Dict, Tuple, TypeVar +from typing import List, Any, Dict, Tuple, TypeVar, TYPE_CHECKING from immutables import Map from pytools import UniqueNameGenerator from pytato.array import (Array, IndexLambda, Stack, Concatenate, @@ -83,6 +83,15 @@ class ToIndexLambdaMixin: else s for s in shape) + if TYPE_CHECKING: + def rec( + self, expr: ToIndexLambdaT, *args: Any, + **kwargs: Any) -> ToIndexLambdaT: + # type-ignore-reason: mypy is right as we are attempting to make + # guarantees about other super-classes. + return super().rec( # type: ignore[no-any-return,misc] + expr, *args, **kwargs) + def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: return IndexLambda(expr=expr.expr, shape=self._rec_shape(expr.shape),