From b6866767d658063bc38d7e94d1593b2c545d8d6f Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 27 Jun 2022 10:20:07 +0300
Subject: [PATCH] split dataclass_array_container for easier modification

---
 arraycontext/container/dataclass.py | 69 +++++++++++++++++++++++------
 1 file changed, 56 insertions(+), 13 deletions(-)

diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index edbb450..4f60abd 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -30,19 +30,24 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Union, get_args
+from typing import Tuple, Union, get_args
 try:
     # NOTE: only available in python >= 3.8
     from typing import get_origin
 except ImportError:
     from typing_extensions import get_origin
 
-from dataclasses import fields
+from dataclasses import Field, is_dataclass, fields
 from arraycontext.container import is_array_container_type
 
 
 # {{{ dataclass containers
 
+def is_array_type(tp: type) -> bool:
+    from arraycontext import Array
+    return tp is Array or is_array_container_type(tp)
+
+
 def dataclass_array_container(cls: type) -> type:
     """A class decorator that makes the class to which it is applied an
     :class:`ArrayContainer` by registering appropriate implementations of
@@ -51,24 +56,37 @@ def dataclass_array_container(cls: type) -> type:
 
     Attributes that are not array containers are allowed. In order to decide
     whether an attribute is an array container, the declared attribute type
-    is checked by the criteria from :func:`is_array_container_type`.
+    is checked by the criteria from :func:`is_array_container_type`. This
+    includes some support for type annotations:
+
+    * a :class:`typing.Union` of array containers is considered an array container.
+    * other type annotations, e.g. :class:`typing.Optional`, are not considered
+      array containers, even if they wrap one.
     """
-    from dataclasses import is_dataclass, Field
+
     assert is_dataclass(cls)
 
     def is_array_field(f: Field) -> bool:
-        from arraycontext import Array
+        # NOTE: unions of array containers are treated separately to handle
+        # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
+        # they can work seamlessly with arithmetic and traversal.
+        #
+        # `Optional[ArrayContainer]` is not allowed, since `None` is not
+        # handled by `with_container_arithmetic`, which is the common case
+        # for current container usage. Other type annotations, e.g.
+        # `Tuple[Container, Container]`, are also not allowed, as they do not
+        # work with `with_container_arithmetic`.
+        #
+        # This is not set in stone, but mostly driven by current usage!
 
         origin = get_origin(f.type)
         if origin is Union:
-            if not all(
-                    arg is Array or is_array_container_type(arg)
-                    for arg in get_args(f.type)):
+            if all(is_array_type(arg) for arg in get_args(f.type)):
+                return True
+            else:
                 raise TypeError(
                         f"Field '{f.name}' union contains non-array container "
                         "arguments. All arguments must be array containers.")
-            else:
-                return True
 
         if __debug__:
             if not f.init:
@@ -79,8 +97,12 @@ def dataclass_array_container(cls: type) -> type:
                 raise TypeError(
                         f"string annotation on field '{f.name}' not supported")
 
-            from typing import _SpecialForm
-            if isinstance(f.type, _SpecialForm):
+            # NOTE:
+            # * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
+            # * `_SpecialForm` catches `Any`, `Literal`, etc.
+            from typing import (                    # type: ignore[attr-defined]
+                _BaseGenericAlias, _SpecialForm)
+            if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
                 # NOTE: anything except a Union is not allowed
                 raise TypeError(
                         f"typing annotation not supported on field '{f.name}': "
@@ -91,7 +113,7 @@ def dataclass_array_container(cls: type) -> type:
                         f"field '{f.name}' not an instance of 'type': "
                         f"'{f.type!r}'")
 
-        return f.type is Array or is_array_container_type(f.type)
+        return is_array_type(f.type)
 
     from pytools import partition
     array_fields, non_array_fields = partition(is_array_field, fields(cls))
@@ -100,6 +122,27 @@ def dataclass_array_container(cls: type) -> type:
         raise ValueError(f"'{cls}' must have fields with array container type "
                 "in order to use the 'dataclass_array_container' decorator")
 
+    return inject_dataclass_serialization(cls, array_fields, non_array_fields)
+
+
+def inject_dataclass_serialization(
+        cls: type,
+        array_fields: Tuple[Field, ...],
+        non_array_fields: Tuple[Field, ...]) -> type:
+    """Implements :func:`~arraycontext.serialize_container` and
+    :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
+
+    This function modifies *cls* in place, so the returned value is the same
+    object with additional functionality.
+
+    :arg array_fields: fields of the given dataclass *cls* which are considered
+        array containers and should be serialized.
+    :arg non_array_fields: remaining fields of the dataclass *cls* which are
+        copied over from the template array in deserialization.
+    """
+
+    assert is_dataclass(cls)
+
     serialize_expr = ", ".join(
             f"({f.name!r}, ary.{f.name})" for f in array_fields)
     template_kwargs = ", ".join(
-- 
GitLab