From 84a84ced8e9d26f73d366df4091a2e56bf3b8c20 Mon Sep 17 00:00:00 2001 From: Matthew Smith <mjsmith6@illinois.edu> Date: Thu, 26 Aug 2021 10:03:59 -0500 Subject: [PATCH] add ArrayT to fix mypy without limiting ContainerT usefulness --- arraycontext/container/__init__.py | 10 ++++++++-- arraycontext/container/traversal.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 1046f99..ea208cb 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -3,6 +3,11 @@ """ .. currentmodule:: arraycontext +.. class:: ArrayT + :canonical: arraycontext.container.ArrayT + + :class:`~typing.TypeVar` for arrays. + .. class:: ContainerT :canonical: arraycontext.container.ContainerT @@ -54,11 +59,12 @@ THE SOFTWARE. from functools import singledispatch from arraycontext.context import ArrayContext -from typing import Any, Iterable, Tuple, TypeVar, Optional +from typing import Any, Iterable, Tuple, TypeVar, Optional, Union import numpy as np +ArrayT = TypeVar("ArrayT") ContainerT = TypeVar("ContainerT") -ArrayOrContainerT = TypeVar("ArrayOrContainerT") +ArrayOrContainerT = Union[ArrayT, ContainerT] # {{{ ArrayContainer diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 76f7647..ed6dc0c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -142,13 +142,13 @@ def _multimap_array_container_impl( if len(container_indices) == 1 and reduce_func is None: # NOTE: if we just have one ArrayContainer in args, passing it through # _map_array_container_impl should be faster - def wrapper(ary: ArrayOrContainerT) -> ArrayOrContainerT: + def wrapper(ary: ContainerT) -> ContainerT: new_args = list(args) new_args[container_indices[0]] = ary return f(*new_args) update_wrapper(wrapper, f) - template_ary: ArrayOrContainerT = args[container_indices[0]] + template_ary: ContainerT = args[container_indices[0]] return _map_array_container_impl( wrapper, template_ary, leaf_cls=leaf_cls, recursive=recursive) -- GitLab