From 22f9ead841b2358f8859293e7595cb451f715c8e Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Fri, 15 Nov 2024 12:13:11 +0200
Subject: [PATCH] feat: improve dataclass container

---
 arraycontext/container/dataclass.py |  8 +++--
 test/test_utils.py                  | 55 ++++++++++++++++++++++++-----
 2 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index ec4c37f..ae9ab48 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -59,6 +59,8 @@ def dataclass_array_container(cls: type) -> type:
       array containers, even if they wrap one.
     """
 
+    from types import GenericAlias, UnionType
+
     assert is_dataclass(cls)
 
     def is_array_field(f: Field) -> bool:
@@ -75,7 +77,8 @@ def dataclass_array_container(cls: type) -> type:
         # This is not set in stone, but mostly driven by current usage!
 
         origin = get_origin(f.type)
-        if origin is Union:
+        # NOTE: `UnionType` is returned when using `Type1 | Type2`
+        if origin in (Union, UnionType):
             if all(is_array_type(arg) for arg in get_args(f.type)):
                 return True
             else:
@@ -94,13 +97,14 @@ def dataclass_array_container(cls: type) -> type:
                         f"Field with 'init=False' not allowed: '{f.name}'")
 
             # NOTE:
+            # * `GenericAlias` catches typed `list`, `tuple`, etc.
             # * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
             # * `_SpecialForm` catches `Any`, `Literal`, etc.
             from typing import (  # type: ignore[attr-defined]
                 _BaseGenericAlias,
                 _SpecialForm,
             )
-            if isinstance(f.type, _BaseGenericAlias | _SpecialForm):
+            if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm):
                 # NOTE: anything except a Union is not allowed
                 raise TypeError(
                         f"Typing annotation not supported on field '{f.name}': "
diff --git a/test/test_utils.py b/test/test_utils.py
index 04817d6..db9ed82 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -49,9 +49,9 @@ def test_pt_actx_key_stringification_uniqueness():
 
 def test_dataclass_array_container() -> None:
     from dataclasses import dataclass, field
-    from typing import Optional
+    from typing import Optional, Tuple  # noqa: UP035
 
-    from arraycontext import dataclass_array_container
+    from arraycontext import Array, dataclass_array_container
 
     # {{{ string fields
 
@@ -60,7 +60,7 @@ def test_dataclass_array_container() -> None:
         x: np.ndarray
         y: "np.ndarray"
 
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match="String annotation on field 'y'"):
         # NOTE: cannot have string annotations in container
         dataclass_array_container(ArrayContainerWithStringTypes)
 
@@ -73,12 +73,32 @@ def test_dataclass_array_container() -> None:
         x: np.ndarray
         y: Optional[np.ndarray]
 
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
         # NOTE: cannot have wrapped annotations (here by `Optional`)
         dataclass_array_container(ArrayContainerWithOptional)
 
     # }}}
 
+    # {{{ type annotations
+
+    @dataclass
+    class ArrayContainerWithTuple:
+        x: Array
+        y: Tuple[Array, Array]
+
+    with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"):
+        dataclass_array_container(ArrayContainerWithTuple)
+
+    @dataclass
+    class ArrayContainerWithTupleAlt:
+        x: Array
+        y: tuple[Array, Array]
+
+    with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"):
+        dataclass_array_container(ArrayContainerWithTupleAlt)
+
+    # }}}
+
     # {{{ field(init=False)
 
     @dataclass
@@ -87,7 +107,7 @@ def test_dataclass_array_container() -> None:
         y: np.ndarray = field(default_factory=lambda: np.zeros(42),
                               init=False, repr=False)
 
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Field with 'init=False' not allowed"):
         # NOTE: init=False fields are not allowed
         dataclass_array_container(ArrayContainerWithInitFalse)
 
@@ -95,8 +115,6 @@ def test_dataclass_array_container() -> None:
 
     # {{{ device arrays
 
-    from arraycontext import Array
-
     @dataclass
     class ArrayContainerWithArray:
         x: Array
@@ -126,6 +144,13 @@ def test_dataclass_container_unions() -> None:
 
     dataclass_array_container(ArrayContainerWithUnion)
 
+    @dataclass
+    class ArrayContainerWithUnionAlt:
+        x: np.ndarray
+        y: np.ndarray | Array
+
+    dataclass_array_container(ArrayContainerWithUnionAlt)
+
     # }}}
 
     # {{{ non-container union
@@ -135,12 +160,26 @@ def test_dataclass_container_unions() -> None:
         x: np.ndarray
         y: Union[np.ndarray, float]
 
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
         # NOTE: float is not an ArrayContainer, so y should fail
         dataclass_array_container(ArrayContainerWithWrongUnion)
 
     # }}}
 
+    # {{{ optional union
+
+    @dataclass
+    class ArrayContainerWithOptionalUnion:
+        x: np.ndarray
+        y: np.ndarray | None
+
+    with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
+        # NOTE: None is not an ArrayContainer, so y should fail
+        dataclass_array_container(ArrayContainerWithWrongUnion)
+
+    # }}}
+
+
 # }}}
 
 
-- 
GitLab