diff --git a/pytato/transform.py b/pytato/transform.py index ce7552671cedc70e792396cb1a8e4997d4d00275..79ba5c91c149ed20798cf8dd680be2fe111537b4 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -24,7 +24,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic +from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, + List) from pytato.array import ( Array, IndexLambda, Placeholder, MatrixProduct, Stack, Roll, @@ -566,13 +567,14 @@ class TopoSortMapper(CachedWalkMapper): def __init__(self) -> None: super().__init__() - self.topological_order = [] + self.topological_order: List[Array] = [] - def post_visit(self, expr: Any) -> bool: + def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) # }}} + # {{{ mapper frontends def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays,