From 456d8933a2fa265cee30d5aaf34f903ddc1063eb Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 5 Apr 2022 10:04:24 -0500
Subject: [PATCH] allow dataclass containers with only DeviceArrays

---
 arraycontext/__init__.py            | 26 ++++++-----
 arraycontext/container/dataclass.py |  3 +-
 arraycontext/container/traversal.py | 10 ++---
 arraycontext/context.py             | 68 +++++++++++++++++++++++------
 doc/conf.py                         |  5 ---
 setup.py                            |  1 +
 test/test_utils.py                  | 13 ++++++
 7 files changed, 90 insertions(+), 36 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 8206fb8..e8f34d4 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -29,7 +29,7 @@ THE SOFTWARE.
 """
 
 import sys
-from .context import ArrayContext, DeviceArray, DeviceScalar
+from .context import ArrayContext, Array, Scalar
 
 from .transform_metadata import (CommonSubexpressionTag,
         ElementwiseMapKernelTag)
@@ -75,7 +75,7 @@ from .loopy import make_loopy_program
 
 
 __all__ = (
-        "ArrayContext", "DeviceScalar", "DeviceArray",
+        "ArrayContext", "Scalar", "Array",
 
         "CommonSubexpressionTag",
         "ElementwiseMapKernelTag",
@@ -125,24 +125,26 @@ def _deprecated_acf():
 
 
 _depr_name_to_replacement_and_obj = {
-        "get_container_context": ("get_container_context_opt",
-            get_container_context_opt),
-        "FirstAxisIsElementsTag":
-        ("meshmode.transform_metadata.FirstAxisIsElementsTag",
-            _FirstAxisIsElementsTag),
-        "_acf":
-        ("<no replacement yet>", _deprecated_acf),
+        "get_container_context": (
+            "get_container_context_opt",
+            get_container_context_opt, 2022),
+        "FirstAxisIsElementsTag": (
+            "meshmode.transform_metadata.FirstAxisIsElementsTag",
+            _FirstAxisIsElementsTag, 2022),
+        "_acf": ("<no replacement yet>", _deprecated_acf, 2022),
+        "DeviceArray": ("Array", Array, 2023),
+        "DeviceScalar": ("Scalar", Scalar, 2023),
         }
 
 if sys.version_info >= (3, 7):
     def __getattr__(name):
         replacement_and_obj = _depr_name_to_replacement_and_obj.get(name, None)
         if replacement_and_obj is not None:
-            replacement, obj = replacement_and_obj
+            replacement, obj, year = replacement_and_obj
             from warnings import warn
             warn(f"'arraycontext.{name}' is deprecated. "
                     f"Use '{replacement}' instead. "
-                    f"'arraycontext.{name}' will continue to work until 2022.",
+                    f"'arraycontext.{name}' will continue to work until {year}.",
                     DeprecationWarning, stacklevel=2)
             return obj
         else:
@@ -151,6 +153,8 @@ else:
     FirstAxisIsElementsTag = _FirstAxisIsElementsTag
     _acf = _deprecated_acf
     get_container_context = get_container_context_opt
+    DeviceArray = Array
+    DeviceScalar = Scalar
 
 # }}}
 
diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index 203246c..150d1d6 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -70,7 +70,8 @@ def dataclass_array_container(cls: type) -> type:
                         f"field '{f.name}' not an instance of 'type': "
                         f"'{f.type!r}'")
 
-        return is_array_container_type(f.type)
+        from arraycontext import Array
+        return f.type is Array or is_array_container_type(f.type)
 
     from pytools import partition
     array_fields, non_array_fields = partition(is_array_field, fields(cls))
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 23cec03..de89a6b 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -68,7 +68,7 @@ from functools import update_wrapper, partial, singledispatch
 
 import numpy as np
 
-from arraycontext.context import ArrayContext, DeviceArray, _ScalarLike
+from arraycontext.context import ArrayContext, Array, _ScalarLike
 from arraycontext.container import (
         ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
         serialize_container, deserialize_container)
@@ -384,7 +384,7 @@ def rec_keyed_map_array_container(
 def map_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[[Any], Any],
-        ary: ArrayOrContainerT) -> "DeviceArray":
+        ary: ArrayOrContainerT) -> "Array":
     """Perform a map-reduce over array containers.
 
     :param reduce_func: callable used to reduce over the components of *ary*
@@ -407,7 +407,7 @@ def map_reduce_array_container(
 def multimap_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[..., Any],
-        *args: Any) -> "DeviceArray":
+        *args: Any) -> "Array":
     r"""Perform a map-reduce over multiple array containers.
 
     :param reduce_func: callable used to reduce over the components of any
@@ -431,7 +431,7 @@ def rec_map_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[[Any], Any],
         ary: ArrayOrContainerT,
-        leaf_class: Optional[type] = None) -> "DeviceArray":
+        leaf_class: Optional[type] = None) -> "Array":
     """Perform a map-reduce over array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of *ary*
@@ -489,7 +489,7 @@ def rec_multimap_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[..., Any],
         *args: Any,
-        leaf_class: Optional[type] = None) -> "DeviceArray":
+        leaf_class: Optional[type] = None) -> "Array":
     r"""Perform a map-reduce over multiple array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of any
diff --git a/arraycontext/context.py b/arraycontext/context.py
index 6c13a33..d206a87 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -75,18 +75,8 @@ The interface of an array context
 
 .. currentmodule:: arraycontext
 
-.. class:: DeviceArray
-
-    A (type alias for an) array type supported by the :class:`ArrayContext`
-    meant to aid in typing annotations. For a explicit list of supported types
-    see :attr:`ArrayContext.array_types`.
-
-.. class:: DeviceScalar
-
-    A (type alias for a) scalar type supported by the :class:`ArrayContext`
-    meant to aid in typing annotations, e.g. for reductions. In :mod:`numpy`
-    terminology, this is just an array with a shape of ``()``.
-
+.. autoclass:: Array
+.. autoclass:: Scalar
 .. autoclass:: ArrayContext
 """
 
@@ -123,10 +113,60 @@ from pytools import memoize_method
 from pytools.tag import Tag
 
 
-DeviceArray = Any
-DeviceScalar = Any
+# {{{ typing
+
 _ScalarLike = Union[int, float, complex, np.generic]
 
+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol                  # type: ignore[misc]
+
+
+class Array(Protocol):
+    """A :class:`~typing.Protocol` for the array type supported by
+    :class:`ArrayContext`.
+
+    This is meant to aid in typing annotations. For a explicit list of
+    supported types see :attr:`ArrayContext.array_types`.
+
+    .. attribute:: shape
+    .. attribute:: dtype
+    """
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        ...
+
+    @property
+    def dtype(self) -> "np.dtype[Any]":
+        ...
+
+
+class Scalar(Protocol):
+    """A :class:`~typing.Protocol` for the scalar type supported by
+    :class:`ArrayContext`.
+
+    In :mod:`numpy` terminology, this is just an array with a shape of ``()``.
+
+    This is meant to aid in typing annotations. For a explicit list of
+    supported types see :attr:`ArrayContext.array_types`.
+
+    .. attribute:: shape
+    .. attribute:: dtype
+    """
+
+    @property
+    def shape(self) -> Tuple[()]:
+        ...
+
+    @property
+    def dtype(self) -> "np.dtype[Any]":
+        ...
+
+
+# }}}
+
 
 # {{{ ArrayContext
 
diff --git a/doc/conf.py b/doc/conf.py
index 29f026e..bee0e10 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -14,11 +14,6 @@ exec(compile(open("../arraycontext/version.py").read(), "../arraycontext/version
 version = ".".join(str(x) for x in ver_dic["VERSION"])
 release = ver_dic["VERSION_TEXT"]
 
-autodoc_type_aliases = {
-        "DeviceScalar": "arraycontext.DeviceScalar",
-        "DeviceArray": "arraycontext.DeviceArray",
-        }
-
 intersphinx_mapping = {
     "https://docs.python.org/3/": None,
     "https://numpy.org/doc/stable/": None,
diff --git a/setup.py b/setup.py
index 8b0d677..06e898d 100644
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,7 @@ def main():
             "pytest>=2.3",
             "loopy>=2019.1",
             "dataclasses; python_version<'3.7'",
+            "typing_extensions; python_version<'3.8'",
             "types-dataclasses",
         ],
         package_data={"arraycontext": ["py.typed"]},
diff --git a/test/test_utils.py b/test/test_utils.py
index 08b6c3a..ac3127f 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -90,6 +90,19 @@ def test_dataclass_array_container():
 
     # }}}
 
+    # {{{ device arrays
+
+    from arraycontext import Array
+
+    @dataclass
+    class ArrayContainerWithArray:
+        x: Array
+        y: Array
+
+    dataclass_array_container(ArrayContainerWithArray)
+
+    # }}}
+
 # }}}
 
 
-- 
GitLab