From 0569eab090b8caf30a9e1a00acd6636b72f79620 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 29 Jun 2022 15:18:58 -0500
Subject: [PATCH] Clean up array and container type annotations

Co-authored-by:  Michael Campbell <mtcampbe@illinois.edu>
---
 arraycontext/__init__.py               |  22 +++-
 arraycontext/container/__init__.py     |  83 ++++++++-----
 arraycontext/container/traversal.py    | 163 ++++++++++++++-----------
 arraycontext/context.py                | 101 +++++++++------
 arraycontext/impl/jax/__init__.py      |   4 +-
 arraycontext/impl/pyopencl/__init__.py |   4 +-
 arraycontext/impl/pytato/__init__.py   |  10 +-
 arraycontext/impl/pytato/compile.py    |   4 +-
 test/test_arraycontext.py              |   9 +-
 9 files changed, 245 insertions(+), 155 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index b8287cf..06e0b96 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -30,8 +30,14 @@ THE SOFTWARE.
 
 import sys
 from .context import (
+        ArrayContext,
+
+        Scalar, ScalarLike,
         Array, ArrayT,
-        ArrayContext, Scalar, tag_axes)
+        ArrayOrContainer, ArrayOrContainerT,
+        ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT,
+
+        tag_axes)
 
 from .transform_metadata import (CommonSubexpressionTag,
         ElementwiseMapKernelTag)
@@ -40,8 +46,8 @@ from .transform_metadata import (CommonSubexpressionTag,
 from .metadata import _FirstAxisIsElementsTag
 
 from .container import (
-        ArrayOrContainerT as ArrayOrContainer, ArrayOrContainerT,
-        ArrayContainer, NotAnArrayContainerError,
+        ArrayContainer, ArrayContainerT,
+        NotAnArrayContainerError,
         is_array_container, is_array_container_type,
         get_container_context_opt,
         get_container_context_recursively, get_container_context_recursively_opt,
@@ -81,14 +87,18 @@ from .loopy import make_loopy_program
 
 
 __all__ = (
+        "ArrayContext", "Scalar", "Array",
+        "Scalar", "ScalarLike",
         "Array", "ArrayT",
-        "ArrayContext", "Scalar", "tag_axes",
+        "ArrayOrContainer", "ArrayOrContainerT",
+        "ArrayOrContainerOrScalar", "ArrayOrContainerOrScalarT",
+        "tag_axes",
 
         "CommonSubexpressionTag",
         "ElementwiseMapKernelTag",
 
-        "ArrayOrContainer", "ArrayOrContainerT",
-        "ArrayContainer", "NotAnArrayContainerError",
+        "ArrayContainer", "ArrayContainerT",
+        "NotAnArrayContainerError",
         "is_array_container", "is_array_container_type",
         "get_container_context_opt",
         "get_container_context_recursively_opt",
diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 789dd29..71bccee 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -3,22 +3,10 @@
 """
 .. currentmodule:: arraycontext
 
-.. class:: ArrayT
-    :canonical: arraycontext.container.ArrayT
-
-    :class:`~typing.TypeVar` for arrays.
-
-.. class:: ContainerT
-    :canonical: arraycontext.container.ContainerT
-
-    :class:`~typing.TypeVar` for array container-like objects.
-
-.. class:: ArrayOrContainerT
-    :canonical: arraycontext.container.ArrayOrContainerT
-
-    :class:`~typing.TypeVar` for arrays or array container-like objects.
-
 .. autoclass:: ArrayContainer
+.. class:: ArrayContainerT
+
+    A type variable with a lower bound of :class:`ArrayContainer`.
 
 .. autoexception:: NotAnArrayContainerError
 
@@ -38,8 +26,23 @@ Context retrieval
 ---------------------------------------------------------
 
 .. autofunction:: register_multivector_as_array_container
+
+.. currentmodule:: arraycontext.container
+
+Canonical locations for type annotations
+----------------------------------------
+
+.. class:: ArrayContainerT
+
+    :canonical: arraycontext.ArrayContainerT
+
+.. class:: ArrayOrContainerT
+
+    :canonical: arraycontext.ArrayOrContainerT
 """
 
+from __future__ import annotations
+
 
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -67,22 +70,24 @@ THE SOFTWARE.
 
 from functools import singledispatch
 from arraycontext.context import ArrayContext
-from typing import Any, Iterable, Tuple, TypeVar, Optional, Union, TYPE_CHECKING
+from typing import Any, Iterable, Tuple, Optional, TypeVar, Protocol, TYPE_CHECKING
 import numpy as np
 
-ArrayT = TypeVar("ArrayT")
-ContainerT = TypeVar("ContainerT")
-ArrayOrContainerT = Union[ArrayT, ContainerT]
+# For use in singledispatch type annotations, because sphinx can't figure out
+# what 'np' is.
+import numpy
+
 
 if TYPE_CHECKING:
     from pymbolic.geometric_algebra import MultiVector
+    from arraycontext import ArrayOrContainer
 
 
 # {{{ ArrayContainer
 
-class ArrayContainer:
-    r"""
-    A generic container for the array type supported by the
+class ArrayContainer(Protocol):
+    """
+    A protocol for generic containers of the array type supported by the
     :class:`ArrayContext`.
 
     The functionality required for the container to operated is supplied via
@@ -113,17 +118,31 @@ class ArrayContainer:
 
     .. note::
 
-        This class is used in type annotation. Inheriting from it confers no
-        special meaning or behavior.
+        This class is used in type annotation and as a marker of array container
+        attributes for :func:`~arraycontext.dataclass_array_container`.
+        As a protocol, it is not intended as a superclass.
     """
 
+    # Array containers do not need to have any particular features, so this
+    # protocol is deliberately empty.
+
+    # This *is* used as a type annotation in dataclasses that are processed
+    # by dataclass_array_container, where it's used to recognize attributes
+    # that are container-typed.
+
+    pass
+
+
+ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
+
 
 class NotAnArrayContainerError(TypeError):
     """:class:`TypeError` subclass raised when an array container is expected."""
 
 
 @singledispatch
-def serialize_container(ary: Any) -> Iterable[Tuple[Any, Any]]:
+def serialize_container(
+        ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]:
     r"""Serialize the array container into an iterable over its components.
 
     The order of the components and their identifiers are entirely under
@@ -149,7 +168,9 @@ def serialize_container(ary: Any) -> Iterable[Tuple[Any, Any]]:
 
 
 @singledispatch
-def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any:
+def deserialize_container(
+        template: ArrayContainerT,
+        iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT:
     """Deserialize an iterable into an array container.
 
     :param template: an instance of an existing object that
@@ -214,7 +235,8 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]:
 # {{{ object arrays as array containers
 
 @serialize_container.register(np.ndarray)
-def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
+def _serialize_ndarray_container(
+        ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]:
     if ary.dtype.char != "O":
         raise NotAnArrayContainerError(
                 f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'")
@@ -232,9 +254,10 @@ def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
 
 
 @deserialize_container.register(np.ndarray)
-def _deserialize_ndarray_container(
-        template: np.ndarray,
-        iterable: Iterable[Tuple[Any, Any]]) -> np.ndarray:
+# https://github.com/python/mypy/issues/13040
+def _deserialize_ndarray_container(  # type: ignore[misc]
+        template: numpy.ndarray,
+        iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray:
     # disallow subclasses
     assert type(template) is np.ndarray
     assert template.dtype.char == "O"
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 2911340..6c858fe 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -39,6 +39,8 @@ Algebraic operations
 .. autofunction:: outer
 """
 
+from __future__ import annotations
+
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
 """
@@ -63,15 +65,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Any, Callable, Iterable, List, Optional, Union, Tuple
+from typing import Any, Callable, Iterable, List, Optional, Union, Tuple, cast
 from functools import update_wrapper, partial, singledispatch
 from warnings import warn
 
 import numpy as np
 
-from arraycontext.context import ArrayContext, Array, _ScalarLike
+from arraycontext.context import (
+    ArrayT, ArrayOrContainer, ArrayOrContainerT,
+    ArrayOrContainerOrScalar, ScalarLike,
+    ArrayContext, Array
+)
 from arraycontext.container import (
-        ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
+        NotAnArrayContainerError,
+        ArrayContainer,
         serialize_container, deserialize_container,
         get_container_context_recursively_opt)
 
@@ -79,10 +86,10 @@ from arraycontext.container import (
 # {{{ array container traversal helpers
 
 def _map_array_container_impl(
-        f: Callable[[Any], Any],
-        ary: ArrayOrContainerT, *,
+        f: Callable[[ArrayOrContainer], ArrayOrContainer],
+        ary: ArrayOrContainer, *,
         leaf_cls: Optional[type] = None,
-        recursive: bool = False) -> ArrayOrContainerT:
+        recursive: bool = False) -> ArrayOrContainer:
     """Helper for :func:`rec_map_array_container`.
 
     :param leaf_cls: class on which we call *f* directly. This is mostly
@@ -90,7 +97,7 @@ def _map_array_container_impl(
         specific container classes. By default, the recursion is stopped when
         a non-:class:`ArrayContainer` class is encountered.
     """
-    def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
+    def rec(_ary: ArrayOrContainer) -> ArrayOrContainer:
         if type(_ary) is leaf_cls:  # type(ary) is never None
             return f(_ary)
 
@@ -110,9 +117,10 @@ def _map_array_container_impl(
 def _multimap_array_container_impl(
         f: Callable[..., Any],
         *args: Any,
-        reduce_func: Callable[[ContainerT, Iterable[Tuple[Any, Any]]], Any] = None,
+        reduce_func: Optional[Callable[
+            [ArrayContainer, Iterable[Tuple[Any, Any]]], Any]] = None,
         leaf_cls: Optional[type] = None,
-        recursive: bool = False) -> ArrayOrContainerT:
+        recursive: bool = False) -> ArrayOrContainer:
     """Helper for :func:`rec_multimap_array_container`.
 
     :param leaf_cls: class on which we call *f* directly. This is mostly
@@ -198,7 +206,7 @@ def _multimap_array_container_impl(
             return f(*new_args)
 
         update_wrapper(wrapper, f)
-        template_ary: ContainerT = args[container_indices[0]]
+        template_ary: ArrayContainer = args[container_indices[0]]
         return _map_array_container_impl(
                 wrapper, template_ary,
                 leaf_cls=leaf_cls, recursive=recursive)
@@ -221,7 +229,7 @@ def _multimap_array_container_impl(
 
 def map_array_container(
         f: Callable[[Any], Any],
-        ary: ArrayOrContainerT) -> ArrayOrContainerT:
+        ary: ArrayOrContainer) -> ArrayOrContainer:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but
@@ -259,8 +267,8 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
 
 def rec_map_array_container(
         f: Callable[[Any], Any],
-        ary: ArrayOrContainerT,
-        leaf_class: Optional[type] = None) -> ArrayOrContainerT:
+        ary: ArrayOrContainer,
+        leaf_class: Optional[type] = None) -> ArrayOrContainer:
     r"""Applies *f* recursively to an :class:`ArrayContainer`.
 
     For a non-recursive version see :func:`map_array_container`.
@@ -272,15 +280,15 @@ def rec_map_array_container(
 
 
 def mapped_over_array_containers(
-        f: Optional[Callable[[Any], Any]] = None,
+        f: Optional[Callable[[ArrayOrContainer], ArrayOrContainer]] = None,
         leaf_class: Optional[type] = None) -> Union[
-            Callable[[ArrayOrContainerT], ArrayOrContainerT],
+            Callable[[ArrayOrContainer], ArrayOrContainer],
             Callable[
                 [Callable[[Any], Any]],
-                Callable[[ArrayOrContainerT], ArrayOrContainerT]]]:
+                Callable[[ArrayOrContainer], ArrayOrContainer]]]:
     """Decorator around :func:`rec_map_array_container`."""
-    def decorator(g: Callable[[Any], Any]) -> Callable[
-            [ArrayOrContainerT], ArrayOrContainerT]:
+    def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[
+            [ArrayOrContainer], ArrayOrContainer]:
         wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
         update_wrapper(wrapper, g)
         return wrapper
@@ -329,11 +337,14 @@ def multimapped_over_array_containers(
 
 # {{{ keyed array container traversal
 
+KeyType = Any
+
+
 def keyed_map_array_container(
         f: Callable[
-            [Any, ArrayOrContainerT],
-            ArrayOrContainerT],
-        ary: ArrayOrContainerT) -> ArrayOrContainerT:
+            [KeyType, ArrayOrContainer],
+            ArrayOrContainer],
+        ary: ArrayOrContainer) -> ArrayOrContainer:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`map_array_container`, but *f* also takes an
@@ -356,8 +367,8 @@ def keyed_map_array_container(
 
 
 def rec_keyed_map_array_container(
-        f: Callable[[Tuple[Any, ...], ArrayT], ArrayT],
-        ary: ArrayOrContainerT) -> ArrayOrContainerT:
+        f: Callable[[Tuple[KeyType, ...], ArrayT], ArrayT],
+        ary: ArrayOrContainer) -> ArrayOrContainer:
     """
     Works similarly to :func:`rec_map_array_container`, except that *f* also
     takes in a traversal path to the leaf array. The traversal path argument is
@@ -370,7 +381,7 @@ def rec_keyed_map_array_container(
         try:
             iterable = serialize_container(_ary)
         except NotAnArrayContainerError:
-            return f(keys, _ary)
+            return cast(ArrayOrContainerT, f(keys, cast(ArrayT, _ary)))
         else:
             return deserialize_container(_ary, [
                 (key, rec(keys + (key,), subary)) for key, subary in iterable
@@ -409,7 +420,7 @@ def map_reduce_array_container(
 def multimap_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[..., Any],
-        *args: Any) -> "Array":
+        *args: Any) -> ArrayOrContainer:
     r"""Perform a map-reduce over multiple array containers.
 
     :param reduce_func: callable used to reduce over the components of any
@@ -421,7 +432,9 @@ def multimap_reduce_array_container(
     """
     # NOTE: this wrapper matches the signature of `deserialize_container`
     # to make plugging into `_multimap_array_container_impl` easier
-    def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any:
+    def _reduce_wrapper(
+            ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]
+            ) -> Array:
         return reduce_func([subary for _, subary in iterable])
 
     return _multimap_array_container_impl(
@@ -432,8 +445,8 @@ def multimap_reduce_array_container(
 def rec_map_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[[Any], Any],
-        ary: ArrayOrContainerT,
-        leaf_class: Optional[type] = None) -> "Array":
+        ary: ArrayOrContainer,
+        leaf_class: Optional[type] = None) -> ArrayOrContainer:
     """Perform a map-reduce over array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of *ary*
@@ -491,7 +504,7 @@ def rec_multimap_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[..., Any],
         *args: Any,
-        leaf_class: Optional[type] = None) -> "Array":
+        leaf_class: Optional[type] = None) -> ArrayOrContainer:
     r"""Perform a map-reduce over multiple array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of any
@@ -509,7 +522,8 @@ def rec_multimap_reduce_array_container(
     """
     # NOTE: this wrapper matches the signature of `deserialize_container`
     # to make plugging into `_multimap_array_container_impl` easier
-    def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any:
+    def _reduce_wrapper(
+            ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]) -> Any:
         return reduce_func([subary for _, subary in iterable])
 
     return _multimap_array_container_impl(
@@ -623,7 +637,7 @@ def with_array_context(ary: ArrayOrContainerT,
 # {{{ flatten / unflatten
 
 def flatten(
-        ary: ArrayOrContainerT, actx: ArrayContext, *,
+        ary: ArrayOrContainer, actx: ArrayContext, *,
         leaf_class: Optional[type] = None,
         ) -> Any:
     """Convert all arrays in the :class:`~arraycontext.ArrayContainer`
@@ -646,32 +660,35 @@ def flatten(
     """
     common_dtype = None
 
-    def _flatten(subary: ArrayOrContainerT) -> List[Any]:
+    def _flatten(subary: ArrayOrContainer) -> List[Array]:
         nonlocal common_dtype
 
         try:
             iterable = serialize_container(subary)
         except NotAnArrayContainerError:
+            subary_c = cast(Array, subary)
+
             if common_dtype is None:
-                common_dtype = subary.dtype
+                common_dtype = subary_c.dtype
 
-            if subary.dtype != common_dtype:
+            if subary_c.dtype != common_dtype:
                 raise ValueError("arrays in container have different dtypes: "
-                        f"got {subary.dtype}, expected {common_dtype}")
+                        f"got {subary_c.dtype}, expected {common_dtype}")
 
             try:
-                flat_subary = actx.np.ravel(subary, order="C")
+                flat_subary = actx.np.ravel(subary_c, order="C")
             except ValueError as exc:
                 # NOTE: we can't do much if the array context fails to ravel,
                 # since it is the one responsible for the actual memory layout
-                if hasattr(subary, "strides"):
-                    strides_msg = f" and strides {subary.strides}"
+                if hasattr(subary_c, "strides"):
+                    # Mypy has a point: nobody promised a strides attr.
+                    strides_msg = f" and strides {subary_c.strides}"  # type: ignore[attr-defined]  # noqa: E501
                 else:
                     strides_msg = ""
 
                 raise NotImplementedError(
                         f"'{type(actx).__name__}.np.ravel' failed to reshape "
-                        f"an array with shape {subary.shape}{strides_msg}. "
+                        f"an array with shape {subary_c.shape}{strides_msg}. "
                         "This functionality needs to be implemented by the "
                         "array context.") from exc
 
@@ -683,7 +700,7 @@ def flatten(
 
         return result
 
-    def _flatten_without_leaf_class(subary: ArrayOrContainerT) -> Any:
+    def _flatten_without_leaf_class(subary: ArrayOrContainer) -> Any:
         result = _flatten(subary)
 
         if len(result) == 1:
@@ -691,7 +708,7 @@ def flatten(
         else:
             return actx.np.concatenate(result)
 
-    def _flatten_with_leaf_class(subary: ArrayOrContainerT) -> Any:
+    def _flatten_with_leaf_class(subary: ArrayOrContainer) -> Any:
         if type(subary) is leaf_class:
             return _flatten_without_leaf_class(subary)
 
@@ -731,46 +748,48 @@ def unflatten(
     offset = 0
     common_dtype = None
 
-    def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
+    def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
         nonlocal offset, common_dtype
 
         try:
             iterable = serialize_container(template_subary)
         except NotAnArrayContainerError:
+            template_subary_c = cast(Array, template_subary)
+
             # {{{ validate subary
 
-            if (offset + template_subary.size) > ary.size:
+            if (offset + template_subary_c.size) > ary.size:
                 raise ValueError("'template' and 'ary' sizes do not match: "
                     "'template' is too large")
 
             if strict:
-                if template_subary.dtype != ary.dtype:
+                if template_subary_c.dtype != ary.dtype:
                     raise ValueError("'template' dtype does not match 'ary': "
-                            f"got {template_subary.dtype}, expected {ary.dtype}")
+                            f"got {template_subary_c.dtype}, expected {ary.dtype}")
             else:
                 # NOTE: still require that *template* has a uniform dtype
                 if common_dtype is None:
-                    common_dtype = template_subary.dtype
+                    common_dtype = template_subary_c.dtype
                 else:
-                    if common_dtype != template_subary.dtype:
+                    if common_dtype != template_subary_c.dtype:
                         raise ValueError("arrays in 'template' have different "
-                                f"dtypes: got {template_subary.dtype}, but "
+                                f"dtypes: got {template_subary_c.dtype}, but "
                                 f"expected {common_dtype}.")
 
             # }}}
 
             # {{{ reshape
 
-            flat_subary = ary[offset:offset + template_subary.size]
+            flat_subary = ary[offset:offset + template_subary_c.size]
             try:
                 subary = actx.np.reshape(flat_subary,
-                        template_subary.shape, order="C")
+                        template_subary_c.shape, order="C")
             except ValueError as exc:
                 # NOTE: we can't do much if the array context fails to reshape,
                 # since it is the one responsible for the actual memory layout
                 raise NotImplementedError(
                         f"'{type(actx).__name__}.np.reshape' failed to reshape "
-                        f"the flat array into shape {template_subary.shape}. "
+                        f"the flat array into shape {template_subary_c.shape}. "
                         "This functionality needs to be implemented by the "
                         "array context.") from exc
 
@@ -782,21 +801,23 @@ def unflatten(
                 # Checking strides for 0 sized arrays is ill-defined
                 # since they cannot be indexed
                 if (
-                    template_subary.strides != subary.strides
-                    and template_subary.size != 0
+                    # Mypy has a point: nobody promised a .strides attribute.
+                    template_subary_c.strides != subary.strides  # type: ignore[attr-defined]  # noqa: E501
+                    and template_subary_c.size != 0
                 ):
                     raise ValueError(
-                            f"strides do not match template: got {subary.strides}, "
-                            f"expected {template_subary.strides}")
+                            # Mypy has a point: nobody promised a .strides attribute.
+                            f"strides do not match template: got {subary.strides}, "   # type: ignore[attr-defined]  # noqa: E501
+                            f"expected {template_subary_c.strides}")
 
             # }}}
 
-            offset += template_subary.size
+            offset += template_subary_c.size
             return subary
         else:
             return deserialize_container(template_subary, [
-                (key, _unflatten(isubary)) for key, isubary in iterable
-                ])
+                        (key, _unflatten(isubary)) for key, isubary in iterable
+                        ])
 
     if not isinstance(ary, actx.array_types):
         raise TypeError("'ary' does not have a type supported by the provided "
@@ -813,11 +834,11 @@ def unflatten(
         raise ValueError("'template' and 'ary' sizes do not match: "
             "'ary' is too large")
 
-    return result
+    return cast(ArrayOrContainerT, result)
 
 
 def flat_size_and_dtype(
-        ary: ArrayOrContainerT) -> "Tuple[int, Optional[np.dtype[Any]]]":
+        ary: ArrayOrContainer) -> "Tuple[int, Optional[np.dtype[Any]]]":
     """
     :returns: a tuple ``(size, dtype)`` that would be the length and
         :class:`numpy.dtype` of the one-dimensional array returned by
@@ -825,20 +846,22 @@ def flat_size_and_dtype(
     """
     common_dtype = None
 
-    def _flat_size(subary: ArrayOrContainerT) -> int:
+    def _flat_size(subary: ArrayOrContainer) -> int:
         nonlocal common_dtype
 
         try:
             iterable = serialize_container(subary)
         except NotAnArrayContainerError:
+            subary_c = cast(Array, subary)
+
             if common_dtype is None:
-                common_dtype = subary.dtype
+                common_dtype = subary_c.dtype
 
-            if subary.dtype != common_dtype:
+            if subary_c.dtype != common_dtype:
                 raise ValueError("arrays in container have different dtypes: "
-                        f"got {subary.dtype}, expected {common_dtype}")
+                        f"got {subary_c.dtype}, expected {common_dtype}")
 
-            return subary.size
+            return subary_c.size
         else:
             return sum(_flat_size(isubary) for _, isubary in iterable)
 
@@ -851,15 +874,15 @@ def flat_size_and_dtype(
 # {{{ numpy conversion
 
 def from_numpy(
-        ary: Union[np.ndarray, _ScalarLike],
-        actx: ArrayContext) -> ArrayOrContainerT:
+        ary: Union[np.ndarray, ScalarLike],
+        actx: ArrayContext) -> ArrayOrContainerOrScalar:
     """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
     to the base array type of :class:`~arraycontext.ArrayContext`.
 
     The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
     """
-    def _from_numpy_with_check(subary: Union[np.ndarray, _ScalarLike]) \
-            -> ArrayOrContainerT:
+    def _from_numpy_with_check(subary: Union[np.ndarray, ScalarLike]) \
+            -> ArrayOrContainerOrScalar:
         if isinstance(subary, np.ndarray) or np.isscalar(subary):
             return actx.from_numpy(subary)
         else:
@@ -868,7 +891,7 @@ def from_numpy(
     return rec_map_array_container(_from_numpy_with_check, ary)
 
 
-def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
+def to_numpy(ary: ArrayOrContainer, actx: ArrayContext) -> Any:
     """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to
     :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*.
 
diff --git a/arraycontext/context.py b/arraycontext/context.py
index b6278e2..74d5863 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -72,16 +72,43 @@ actual array contexts:
     an array expression that has been built up by the user
     (using, e.g. :func:`pytato.generate_loopy`).
 
-The interface of an array context
----------------------------------
 
 .. currentmodule:: arraycontext
 
-.. autoclass:: Array
-.. autoclass:: Scalar
+The interface of an array context
+---------------------------------
+
 .. autoclass:: ArrayContext
+
 .. autofunction:: tag_axes
 
+Types and Type Variables for Arrays and Containers
+--------------------------------------------------
+
+.. autoclass:: Array
+
+.. class:: ArrayT
+
+    A type variable with a lower bound of :class:`Array`.
+
+.. class:: ScalarLike
+
+    A type annotation for scalar types commonly usable with arrays.
+
+See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`.
+
+.. class:: ArrayOrContainer
+
+.. class:: ArrayOrContainerT
+
+    A type variable with a lower bound of :class:`ArrayOrContainer`.
+
+.. class:: ArrayOrContainerOrScalar
+
+.. class:: ArrayOrContainerOrScalarT
+
+    A type variable with a lower bound of :class:`ArrayOrContainerOrScalar`.
+
 Internal typing helpers (do not import)
 ---------------------------------------
 
@@ -91,11 +118,21 @@ This is only here because the documentation tool wants it.
 
 .. class:: SelfType
 
+Canonical locations for type annotations
+----------------------------------------
+
 .. class:: ArrayT
 
-    A type variable, with a lower bound of :class:`Array`.
-"""
+    :canonical: arraycontext.ArrayT
+
+.. class:: ArrayOrContainerT
+
+    :canonical: arraycontext.ArrayOrContainerT
 
+.. class:: ArrayOrContainerOrScalarT
+
+    :canonical: arraycontext.ArrayOrContainerOrScalarT
+"""
 
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -132,11 +169,12 @@ from pytools.tag import ToTagSetConvertible
 
 if TYPE_CHECKING:
     import loopy
+    from arraycontext.container import ArrayContainer
 
 
 # {{{ typing
 
-_ScalarLike = Union[int, float, complex, np.generic]
+ScalarLike = Union[int, float, complex, np.generic]
 
 try:
     from typing import Protocol
@@ -154,6 +192,7 @@ class Array(Protocol):
     supported types see :attr:`ArrayContext.array_types`.
 
     .. attribute:: shape
+    .. attribute:: size
     .. attribute:: dtype
     """
 
@@ -162,28 +201,7 @@ class Array(Protocol):
         ...
 
     @property
-    def dtype(self) -> "np.dtype[Any]":
-        ...
-
-
-ArrayT = TypeVar("ArrayT", bound=Array)
-
-
-class Scalar(Protocol):
-    """A :class:`~typing.Protocol` for the scalar type supported by
-    :class:`ArrayContext`.
-
-    In :mod:`numpy` terminology, this is just an array with a shape of ``()``.
-
-    This is meant to aid in typing annotations. For a explicit list of
-    supported types see :attr:`ArrayContext.array_types`.
-
-    .. attribute:: shape
-    .. attribute:: dtype
-    """
-
-    @property
-    def shape(self) -> Tuple[()]:
+    def size(self) -> int:
         ...
 
     @property
@@ -191,6 +209,18 @@ class Scalar(Protocol):
         ...
 
 
+# deprecated, use ScalarLike instead
+Scalar = ScalarLike
+
+
+ArrayT = TypeVar("ArrayT", bound=Array)
+ArrayOrContainer = Union[Array, "ArrayContainer"]
+ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer)
+ArrayOrContainerOrScalar = Union[Array, "ArrayContainer", ScalarLike]
+ArrayOrContainerOrScalarT = TypeVar(
+        "ArrayOrContainerOrScalarT",
+        bound=ArrayOrContainerOrScalar)
+
 # }}}
 
 
@@ -273,8 +303,8 @@ class ArrayContext(ABC):
 
     @abstractmethod
     def from_numpy(self,
-                   array: Union["np.ndarray[Any, Any]", _ScalarLike]
-                   ) -> Union[Array, _ScalarLike]:
+                   array: Union["np.ndarray[Any, Any]", ScalarLike]
+                   ) -> Union[Array, ScalarLike]:
         r"""
         :returns: the :class:`numpy.ndarray` *array* converted to the
             array context's array type. The returned array will be
@@ -284,8 +314,8 @@ class ArrayContext(ABC):
 
     @abstractmethod
     def to_numpy(self,
-                 array: Union[Array, _ScalarLike]
-                 ) -> Union["np.ndarray[Any, Any]", _ScalarLike]:
+                 array: Union[Array, ScalarLike]
+                 ) -> Union["np.ndarray[Any, Any]", ScalarLike]:
         r"""
         :returns: *array*, an array recognized by the context, converted
             to a :class:`numpy.ndarray`. *array* must be
@@ -293,6 +323,7 @@ class ArrayContext(ABC):
         """
         pass
 
+    @abstractmethod
     def call_loopy(self,
                    program: "loopy.TranslationUnit",
                    **kwargs: Any) -> Dict[str, Array]:
@@ -308,7 +339,7 @@ class ArrayContext(ABC):
         """
 
     @abstractmethod
-    def freeze(self, array: Array) -> Array:
+    def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
         """Return a version of the context-defined array *array* that is
         'frozen', i.e. suitable for long-term storage and reuse. Frozen arrays
         do not support arithmetic. For example, in the context of
@@ -324,7 +355,7 @@ class ArrayContext(ABC):
         """
 
     @abstractmethod
-    def thaw(self, array: Array) -> Array:
+    def thaw(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
         """Take a 'frozen' array and return a new array representing the data in
         *array* that is able to perform arithmetic and other operations, using
         the execution resources of this context. In the context of
diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py
index 154300a..81076ae 100644
--- a/arraycontext/impl/jax/__init__.py
+++ b/arraycontext/impl/jax/__init__.py
@@ -31,7 +31,7 @@ import numpy as np
 
 from typing import Union, Callable, Any
 from pytools.tag import ToTagSetConvertible
-from arraycontext.context import ArrayContext, _ScalarLike
+from arraycontext.context import ArrayContext, ScalarLike
 from arraycontext.container.traversal import (with_array_context,
                                               rec_map_array_container)
 
@@ -70,7 +70,7 @@ class EagerJAXArrayContext(ArrayContext):
         import jax.numpy as jnp
         return jnp.zeros(shape=shape, dtype=dtype)
 
-    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
+    def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
         import jax
         return jax.device_put(array)
 
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index de0b27b..1467246 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -35,7 +35,7 @@ import numpy as np
 
 from pytools.tag import ToTagSetConvertible
 
-from arraycontext.context import ArrayContext, _ScalarLike
+from arraycontext.context import ArrayContext, ScalarLike
 from arraycontext.container.traversal import (rec_map_array_container,
                                               with_array_context)
 
@@ -167,7 +167,7 @@ class PyOpenCLArrayContext(ArrayContext):
                                                  allocator=self.allocator),
                                   axes=None, tags=frozenset())
 
-    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
+    def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
         import pyopencl.array as cl_array
         from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
         return to_tagged_cl_array(cl_array
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 0970ffb..ed3bff8 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -43,7 +43,7 @@ THE SOFTWARE.
 """
 
 import sys
-from arraycontext.context import ArrayContext, _ScalarLike
+from arraycontext.context import ArrayContext, ScalarLike
 from arraycontext.container.traversal import (rec_map_array_container,
                                               with_array_context)
 from arraycontext.metadata import NameHint
@@ -237,7 +237,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
     def clone(self):
         return type(self)(self.queue, self.allocator)
 
-    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
+    def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
         import pytato as pt
         import pyopencl.array as cla
         cl_array = cla.to_device(self.queue, array)
@@ -288,7 +288,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
         import pyopencl.array as cla
         import loopy as lp
 
-        from arraycontext.container import ArrayT
+        from arraycontext.context import ArrayT
         from arraycontext.container.traversal import rec_keyed_map_array_container
         from arraycontext.impl.pytato.utils import (_normalize_pt_expr,
                                                     get_cl_axes_from_pt_axes)
@@ -524,7 +524,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
     def clone(self):
         return type(self)()
 
-    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
+    def from_numpy(self, array: Union[np.ndarray, ScalarLike]):
         import jax
         import pytato as pt
         return pt.make_data_wrapper(jax.device_put(array))
@@ -548,7 +548,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
         import pytato as pt
 
         from jax.numpy import DeviceArray
-        from arraycontext.container import ArrayT
+        from arraycontext.context import ArrayT
         from arraycontext.container.traversal import rec_keyed_map_array_container
         from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
 
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 46404e2..9e92adf 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -29,8 +29,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from arraycontext.container import (ArrayContainer, is_array_container_type,
-                                    ArrayT)
+from arraycontext.context import ArrayT
+from arraycontext.container import ArrayContainer, is_array_container_type
 from arraycontext.impl.pytato import (_BasePytatoArrayContext,
                                       PytatoJAXArrayContext,
                                       PytatoPyOpenCLArrayContext)
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 928f446..e89663c 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -174,7 +174,8 @@ def _serialize_dof_container(ary: DOFArray):
 
 
 @deserialize_container.register(DOFArray)
-def _deserialize_dof_container(
+# https://github.com/python/mypy/issues/13040
+def _deserialize_dof_container(  # type: ignore[misc]
         template, iterable):
     def _raise_index_inconsistency(i, stream_i):
         raise ValueError(
@@ -189,7 +190,8 @@ def _deserialize_dof_container(
 
 
 @with_array_context.register(DOFArray)
-def _with_actx_dofarray(ary, actx):
+# https://github.com/python/mypy/issues/13040
+def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray:  # type: ignore[misc]  # noqa: E501
     return type(ary)(actx, ary.data)
 
 # }}}
@@ -1188,7 +1190,8 @@ class Velocity2D:
 
 
 @with_array_context.register(Velocity2D)
-def _with_actx_velocity_2d(ary, actx):
+# https://github.com/python/mypy/issues/13040
+def _with_actx_velocity_2d(ary, actx):  # type: ignore[misc]
     return type(ary)(ary.u, ary.v, actx)
 
 
-- 
GitLab