From e309ea20beee1527a76764cf087755ee9d8bd449 Mon Sep 17 00:00:00 2001
From: Matthew Smith <mjsmith6@illinois.edu>
Date: Wed, 25 Aug 2021 16:46:35 -0500
Subject: [PATCH] split ArrayContainerT into ContainerT and ArrayOrContainerT

---
 arraycontext/container/__init__.py  | 12 ++++++---
 arraycontext/container/traversal.py | 40 ++++++++++++++---------------
 2 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 5142c05..1046f99 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -3,11 +3,16 @@
 """
 .. currentmodule:: arraycontext
 
-.. class:: ArrayContainerT
-    :canonical: arraycontext.container.ArrayContainerT
+.. class:: ContainerT
+    :canonical: arraycontext.container.ContainerT
 
     :class:`~typing.TypeVar` for array container-like objects.
 
+.. class:: ArrayOrContainerT
+    :canonical: arraycontext.container.ArrayOrContainerT
+
+    :class:`~typing.TypeVar` for arrays or array container-like objects.
+
 .. autoclass:: ArrayContainer
 
 Serialization/deserialization
@@ -52,7 +57,8 @@ from arraycontext.context import ArrayContext
 from typing import Any, Iterable, Tuple, TypeVar, Optional
 import numpy as np
 
-ArrayContainerT = TypeVar("ArrayContainerT")
+ContainerT = TypeVar("ContainerT")
+ArrayOrContainerT = TypeVar("ArrayOrContainerT")
 
 
 # {{{ ArrayContainer
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index cfdea4c..ed6dc0c 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -58,7 +58,7 @@ import numpy as np
 
 from arraycontext.context import ArrayContext
 from arraycontext.container import (
-        ArrayContainerT, is_array_container,
+        ContainerT, ArrayOrContainerT, is_array_container,
         serialize_container, deserialize_container)
 
 
@@ -66,9 +66,9 @@ from arraycontext.container import (
 
 def _map_array_container_impl(
         f: Callable[[Any], Any],
-        ary: ArrayContainerT, *,
+        ary: ArrayOrContainerT, *,
         leaf_cls: Optional[type] = None,
-        recursive: bool = False) -> ArrayContainerT:
+        recursive: bool = False) -> ArrayOrContainerT:
     """Helper for :func:`rec_map_array_container`.
 
     :param leaf_cls: class on which we call *f* directly. This is mostly
@@ -76,7 +76,7 @@ def _map_array_container_impl(
         specific container classes. By default, the recursion is stopped when
         a non-:class:`ArrayContainer` class is encountered.
     """
-    def rec(_ary: ArrayContainerT) -> ArrayContainerT:
+    def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
         if type(_ary) is leaf_cls:  # type(ary) is never None
             return f(_ary)
         elif is_array_container(_ary):
@@ -93,9 +93,9 @@ def _map_array_container_impl(
 def _multimap_array_container_impl(
         f: Callable[..., Any],
         *args: Any,
-        reduce_func: Callable[[Any, Iterable[Tuple[Any, Any]]], Any] = None,
+        reduce_func: Callable[[ContainerT, Iterable[Tuple[Any, Any]]], Any] = None,
         leaf_cls: Optional[type] = None,
-        recursive: bool = False) -> ArrayContainerT:
+        recursive: bool = False) -> ArrayOrContainerT:
     """Helper for :func:`rec_multimap_array_container`.
 
     :param leaf_cls: class on which we call *f* directly. This is mostly
@@ -142,13 +142,13 @@ def _multimap_array_container_impl(
     if len(container_indices) == 1 and reduce_func is None:
         # NOTE: if we just have one ArrayContainer in args, passing it through
         # _map_array_container_impl should be faster
-        def wrapper(ary: ArrayContainerT) -> ArrayContainerT:
+        def wrapper(ary: ContainerT) -> ContainerT:
             new_args = list(args)
             new_args[container_indices[0]] = ary
             return f(*new_args)
 
         update_wrapper(wrapper, f)
-        template_ary: ArrayContainerT = args[container_indices[0]]
+        template_ary: ContainerT = args[container_indices[0]]
         return _map_array_container_impl(
                 wrapper, template_ary,
                 leaf_cls=leaf_cls, recursive=recursive)
@@ -165,7 +165,7 @@ def _multimap_array_container_impl(
 
 def map_array_container(
         f: Callable[[Any], Any],
-        ary: ArrayContainerT) -> ArrayContainerT:
+        ary: ArrayOrContainerT) -> ArrayOrContainerT:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but
@@ -202,7 +202,7 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
 
 def rec_map_array_container(
         f: Callable[[Any], Any],
-        ary: ArrayContainerT) -> ArrayContainerT:
+        ary: ArrayOrContainerT) -> ArrayOrContainerT:
     r"""Applies *f* recursively to an :class:`ArrayContainer`.
 
     For a non-recursive version see :func:`map_array_container`.
@@ -214,7 +214,7 @@ def rec_map_array_container(
 
 
 def mapped_over_array_containers(
-        f: Callable[[Any], Any]) -> Callable[[ArrayContainerT], ArrayContainerT]:
+        f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
     """Decorator around :func:`rec_map_array_container`."""
     wrapper = partial(rec_map_array_container, f)
     update_wrapper(wrapper, f)
@@ -249,7 +249,7 @@ def multimapped_over_array_containers(
 # {{{ keyed array container traversal
 
 def keyed_map_array_container(f: Callable[[Any, Any], Any],
-                              ary: ArrayContainerT) -> ArrayContainerT:
+                              ary: ArrayOrContainerT) -> ArrayOrContainerT:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`map_array_container`, but *f* also takes an
@@ -269,7 +269,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any],
 
 
 def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
-                                  ary: ArrayContainerT) -> ArrayContainerT:
+                                  ary: ArrayOrContainerT) -> ArrayOrContainerT:
     """
     Works similarly to :func:`rec_map_array_container`, except that *f* also
     takes in a traversal path to the leaf array. The traversal path argument is
@@ -278,7 +278,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
     """
 
     def rec(keys: Tuple[Union[str, int], ...],
-            _ary: ArrayContainerT) -> ArrayContainerT:
+            _ary: ArrayOrContainerT) -> ArrayOrContainerT:
         if is_array_container(_ary):
             return deserialize_container(_ary, [
                     (key, rec(keys + (key,), subary))
@@ -297,7 +297,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
 def rec_map_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[[Any], Any],
-        ary: ArrayContainerT) -> Any:
+        ary: ArrayOrContainerT) -> Any:
     """Perform a map-reduce over array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of the
@@ -307,7 +307,7 @@ def rec_map_reduce_array_container(
         type :class:`arraycontext.ArrayContext.array_types` and returns an
         array of the same type or a scalar.
     """
-    def rec(_ary: ArrayContainerT) -> ArrayContainerT:
+    def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
         if is_array_container(_ary):
             return reduce_func([
                 rec(subary) for _, subary in serialize_container(_ary)
@@ -333,7 +333,7 @@ def rec_multimap_reduce_array_container(
     """
     # NOTE: this wrapper matches the signature of `deserialize_container`
     # to make plugging into `_multimap_array_container_impl` easier
-    def _reduce_wrapper(ary: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any:
+    def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any:
         return reduce_func([subary for _, subary in iterable])
 
     return _multimap_array_container_impl(
@@ -347,8 +347,8 @@ def rec_multimap_reduce_array_container(
 
 @singledispatch
 def freeze(
-        ary: ArrayContainerT,
-        actx: Optional[ArrayContext] = None) -> ArrayContainerT:
+        ary: ArrayOrContainerT,
+        actx: Optional[ArrayContext] = None) -> ArrayOrContainerT:
     r"""Freezes recursively by going through all components of the
     :class:`ArrayContainer` *ary*.
 
@@ -372,7 +372,7 @@ def freeze(
 
 
 @singledispatch
-def thaw(ary: ArrayContainerT, actx: ArrayContext) -> ArrayContainerT:
+def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
     r"""Thaws recursively by going through all components of the
     :class:`ArrayContainer` *ary*.
 
-- 
GitLab