From 84a84ced8e9d26f73d366df4091a2e56bf3b8c20 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Thu, 26 Aug 2021 10:03:59 -0500
Subject: [PATCH] add ArrayT to fix mypy without limiting ContainerT usefulness

---
 arraycontext/container/__init__.py  | 10 ++++++++--
 arraycontext/container/traversal.py |  4 ++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 1046f99..ea208cb 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -3,6 +3,11 @@
 """
 .. currentmodule:: arraycontext
 
+.. class:: ArrayT
+    :canonical: arraycontext.container.ArrayT
+
+    :class:`~typing.TypeVar` for arrays.
+
 .. class:: ContainerT
     :canonical: arraycontext.container.ContainerT
 
@@ -54,11 +59,12 @@ THE SOFTWARE.
 
 from functools import singledispatch
 from arraycontext.context import ArrayContext
-from typing import Any, Iterable, Tuple, TypeVar, Optional
+from typing import Any, Iterable, Tuple, TypeVar, Optional, Union
 import numpy as np
 
+ArrayT = TypeVar("ArrayT")
 ContainerT = TypeVar("ContainerT")
-ArrayOrContainerT = TypeVar("ArrayOrContainerT")
+ArrayOrContainerT = Union[ArrayT, ContainerT]
 
 
 # {{{ ArrayContainer
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 76f7647..ed6dc0c 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -142,13 +142,13 @@ def _multimap_array_container_impl(
     if len(container_indices) == 1 and reduce_func is None:
         # NOTE: if we just have one ArrayContainer in args, passing it through
         # _map_array_container_impl should be faster
-        def wrapper(ary: ArrayOrContainerT) -> ArrayOrContainerT:
+        def wrapper(ary: ContainerT) -> ContainerT:
             new_args = list(args)
             new_args[container_indices[0]] = ary
             return f(*new_args)
 
         update_wrapper(wrapper, f)
-        template_ary: ArrayOrContainerT = args[container_indices[0]]
+        template_ary: ContainerT = args[container_indices[0]]
         return _map_array_container_impl(
                 wrapper, template_ary,
                 leaf_cls=leaf_cls, recursive=recursive)
-- 
GitLab