From 1ea77f315bf4ab5bccd920dd387164c1d224f2e0 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 27 Nov 2024 12:59:36 -0600
Subject: [PATCH] dataclass_array_container: support string annotations

---
 arraycontext/container/dataclass.py |  55 +++++--
 pyproject.toml                      |   4 +
 test/test_arraycontext.py           | 182 +----------------------
 test/test_utils.py                  |  22 +--
 test/testlib.py                     | 216 ++++++++++++++++++++++++++++
 5 files changed, 270 insertions(+), 209 deletions(-)
 create mode 100644 test/testlib.py

diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index 9495b34..5ff9dfd 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -31,6 +31,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from collections.abc import Mapping, Sequence
 from dataclasses import Field, fields, is_dataclass
 from typing import Union, get_args, get_origin
 
@@ -58,13 +59,21 @@ def dataclass_array_container(cls: type) -> type:
     * a :class:`typing.Union` of array containers is considered an array container.
     * other type annotations, e.g. :class:`typing.Optional`, are not considered
       array containers, even if they wrap one.
+
+    .. note::
+
+        When type annotations are strings (e.g. because of
+        ``from __future__ import annotations``), this function relies on
+        :func:`inspect.get_annotations` (with ``eval_str=True``) to evaluate
+        them. This means that *cls* must live in a module that is importable,
+        so that the names used in its annotations can be resolved.
     """
 
     from types import GenericAlias, UnionType
 
     assert is_dataclass(cls)
 
-    def is_array_field(f: Field) -> bool:
+    def is_array_field(f: Field, field_type: type) -> bool:
         # NOTE: unions of array containers are treated separately to handle
         # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
         # they can work seamlessly with arithmetic and traversal.
@@ -77,17 +86,17 @@ def dataclass_array_container(cls: type) -> type:
         #
         # This is not set in stone, but mostly driven by current usage!
 
-        origin = get_origin(f.type)
+        origin = get_origin(field_type)
         # NOTE: `UnionType` is returned when using `Type1 | Type2`
         if origin in (Union, UnionType):
-            if all(is_array_type(arg) for arg in get_args(f.type)):
+            if all(is_array_type(arg) for arg in get_args(field_type)):
                 return True
             else:
                 raise TypeError(
                         f"Field '{f.name}' union contains non-array container "
                         "arguments. All arguments must be array containers.")
 
-        if isinstance(f.type, str):
+        if isinstance(field_type, str):
             raise TypeError(
                 f"String annotation on field '{f.name}' not supported. "
                 "(this may be due to 'from __future__ import annotations')")
@@ -105,33 +114,49 @@ def dataclass_array_container(cls: type) -> type:
                 _BaseGenericAlias,
                 _SpecialForm,
             )
-            if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm):
+            if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
                 # NOTE: anything except a Union is not allowed
                 raise TypeError(
                         f"Typing annotation not supported on field '{f.name}': "
-                        f"'{f.type!r}'")
+                        f"'{field_type!r}'")
 
-            if not isinstance(f.type, type):
+            if not isinstance(field_type, type):
                 raise TypeError(
                         f"Field '{f.name}' not an instance of 'type': "
-                        f"'{f.type!r}'")
+                        f"'{field_type!r}'")
+
+        return is_array_type(field_type)
+
+    from inspect import get_annotations
 
-        return is_array_type(f.type)
+    array_fields: list[Field] = []
+    non_array_fields: list[Field] = []
+    cls_ann: Mapping[str, type] | None = None
+    for field in fields(cls):
+        field_type_or_str = field.type
+        if isinstance(field_type_or_str, str):
+            if cls_ann is None:
+                cls_ann = get_annotations(cls, eval_str=True)
+            field_type = cls_ann[field.name]
+        else:
+            field_type = field_type_or_str
 
-    from pytools import partition
-    array_fields, non_array_fields = partition(is_array_field, fields(cls))
+        if is_array_field(field, field_type):
+            array_fields.append(field)
+        else:
+            non_array_fields.append(field)
 
     if not array_fields:
         raise ValueError(f"'{cls}' must have fields with array container type "
                 "in order to use the 'dataclass_array_container' decorator")
 
-    return inject_dataclass_serialization(cls, array_fields, non_array_fields)
+    return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
 
 
-def inject_dataclass_serialization(
+def _inject_dataclass_serialization(
         cls: type,
-        array_fields: tuple[Field, ...],
-        non_array_fields: tuple[Field, ...]) -> type:
+        array_fields: Sequence[Field],
+        non_array_fields: Sequence[Field]) -> type:
     """Implements :func:`~arraycontext.serialize_container` and
     :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
 
diff --git a/pyproject.toml b/pyproject.toml
index a9c1df4..d715981 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -97,12 +97,16 @@ known-first-party = [
 ]
 known-local-folder = [
     "arraycontext",
+    "testlib",
 ]
 lines-after-imports = 2
 required-imports = ["from __future__ import annotations"]
 
 [tool.ruff.lint.per-file-ignores]
 "doc/conf.py" = ["I002"]
+# test_utils.py deliberately omits "from __future__ import annotations" so that
+# array containers defined inside its test functions need not be importable.
+"test/test_utils.py" = ["I002"]
 
 [tool.mypy]
 python_version = "3.10"
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 050bfc8..ab26330 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -34,18 +34,14 @@ from pytools.obj_array import make_obj_array
 from pytools.tag import Tag
 
 from arraycontext import (
-    ArrayContainer,
-    ArrayContext,
     EagerJAXArrayContext,
     NumpyArrayContext,
     PyOpenCLArrayContext,
     PytatoPyOpenCLArrayContext,
     dataclass_array_container,
-    deserialize_container,
     pytest_generate_tests_for_array_contexts,
     serialize_container,
     tag_axes,
-    with_array_context,
     with_container_arithmetic,
 )
 from arraycontext.pytest import (
@@ -55,6 +51,7 @@ from arraycontext.pytest import (
     _PytestPytatoJaxArrayContextFactory,
     _PytestPytatoPyOpenCLArrayContextFactory,
 )
+from testlib import DOFArray, MyContainer, MyContainerDOFBcast, Velocity2D
 
 
 logger = logging.getLogger(__name__)
@@ -116,147 +113,10 @@ def _acf():
 # }}}
 
 
-# {{{ stand-in DOFArray implementation
-
-@with_container_arithmetic(
-        bcasts_across_obj_array=True,
-        bitwise=True,
-        rel_comparison=True,
-        _cls_has_array_context_attr=True,
-        _bcast_actx_array_type=False)
-class DOFArray:
-    def __init__(self, actx, data):
-        if not (actx is None or isinstance(actx, ArrayContext)):
-            raise TypeError("actx must be of type ArrayContext")
-
-        if not isinstance(data, tuple):
-            raise TypeError("'data' argument must be a tuple")
-
-        self.array_context = actx
-        self.data = data
-
-    # prevent numpy broadcasting
-    __array_ufunc__ = None
-
-    def __bool__(self):
-        if len(self) == 1 and self.data[0].size == 1:
-            return bool(self.data[0])
-
-        raise ValueError(
-                "The truth value of an array with more than one element is "
-                "ambiguous. Use actx.np.any(x) or actx.np.all(x)")
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, i):
-        return self.data[i]
-
-    def __repr__(self):
-        return f"DOFArray({self.data!r})"
-
-    @classmethod
-    def _serialize_init_arrays_code(cls, instance_name):
-        return {"_":
-                (f"{instance_name}_i", f"{instance_name}")}
-
-    @classmethod
-    def _deserialize_init_arrays_code(cls, template_instance_name, args):
-        (_, arg), = args.items()
-        # Why tuple([...])? https://stackoverflow.com/a/48592299
-        return (f"{template_instance_name}.array_context, tuple([{arg}])")
-
-    @property
-    def size(self):
-        return sum(ary.size for ary in self.data)
-
-    @property
-    def real(self):
-        return DOFArray(self.array_context, tuple(subary.real for subary in self))
-
-    @property
-    def imag(self):
-        return DOFArray(self.array_context, tuple(subary.imag for subary in self))
-
-
-@serialize_container.register(DOFArray)
-def _serialize_dof_container(ary: DOFArray):
-    return list(enumerate(ary.data))
-
-
-@deserialize_container.register(DOFArray)
-# https://github.com/python/mypy/issues/13040
-def _deserialize_dof_container(  # type: ignore[misc]
-        template, iterable):
-    def _raise_index_inconsistency(i, stream_i):
-        raise ValueError(
-                "out-of-sequence indices supplied in DOFArray deserialization "
-                f"(expected {i}, received {stream_i})")
-
-    return type(template)(
-            template.array_context,
-            data=tuple(
-                v if i == stream_i else _raise_index_inconsistency(i, stream_i)
-                for i, (stream_i, v) in enumerate(iterable)))
-
-
-@with_array_context.register(DOFArray)
-# https://github.com/python/mypy/issues/13040
-def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray:  # type: ignore[misc]
-    return type(ary)(actx, ary.data)
-
-# }}}
-
-
-# {{{ nested containers
-
-@with_container_arithmetic(bcasts_across_obj_array=False,
-        eq_comparison=False, rel_comparison=False,
-        _cls_has_array_context_attr=True,
-        _bcast_actx_array_type=False)
-@dataclass_array_container
-@dataclass(frozen=True)
-class MyContainer:
-    name: str
-    mass: DOFArray | np.ndarray
-    momentum: np.ndarray
-    enthalpy: DOFArray | np.ndarray
-
-    __array_ufunc__ = None
-
-    @property
-    def array_context(self):
-        if isinstance(self.mass, np.ndarray):
-            return next(iter(self.mass)).array_context
-        else:
-            return self.mass.array_context
-
-
-@with_container_arithmetic(
-        bcasts_across_obj_array=False,
-        bcast_container_types=(DOFArray, np.ndarray),
-        matmul=True,
-        rel_comparison=True,
-        _cls_has_array_context_attr=True,
-        _bcast_actx_array_type=False)
-@dataclass_array_container
-@dataclass(frozen=True)
-class MyContainerDOFBcast:
-    name: str
-    mass: DOFArray | np.ndarray
-    momentum: np.ndarray
-    enthalpy: DOFArray | np.ndarray
-
-    @property
-    def array_context(self):
-        if isinstance(self.mass, np.ndarray):
-            return next(iter(self.mass)).array_context
-        else:
-            return self.mass.array_context
-
-
 def _get_test_containers(actx, ambient_dim=2, shapes=50_000):
     from numbers import Number
+
+    from testlib import DOFArray, MyContainer, MyContainerDOFBcast
     if isinstance(shapes, Number | tuple):
         shapes = [shapes]
 
@@ -286,8 +146,6 @@ def _get_test_containers(actx, ambient_dim=2, shapes=50_000):
     return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs,
             bcast_dataclass_of_dofs)
 
-# }}}
-
 
 # {{{ assert_close_to_numpy*
 
@@ -1224,21 +1082,6 @@ def test_norm_ord_none(actx_factory, ndim):
 
 # {{{ test_actx_compile helpers
 
-@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
-@dataclass_array_container
-@dataclass(frozen=True)
-class Velocity2D:
-    u: ArrayContainer
-    v: ArrayContainer
-    array_context: ArrayContext
-
-
-@with_array_context.register(Velocity2D)
-# https://github.com/python/mypy/issues/13040
-def _with_actx_velocity_2d(ary, actx):  # type: ignore[misc]
-    return type(ary)(ary.u, ary.v, actx)
-
-
 def scale_and_orthogonalize(alpha, vel):
     from arraycontext import rec_map_array_container
     actx = vel.array_context
@@ -1353,25 +1196,8 @@ def test_container_equality(actx_factory):
 
 # {{{ test_no_leaf_array_type_broadcasting
 
-@with_container_arithmetic(
-    bcasts_across_obj_array=True,
-    rel_comparison=True,
-    _cls_has_array_context_attr=True,
-    _bcast_actx_array_type=False)
-@dataclass_array_container
-@dataclass(frozen=True)
-class Foo:
-    u: DOFArray
-
-    # prevent numpy arithmetic from taking precedence
-    __array_ufunc__ = None
-
-    @property
-    def array_context(self):
-        return self.u.array_context
-
-
 def test_no_leaf_array_type_broadcasting(actx_factory):
+    from testlib import Foo
     # test lack of support for https://github.com/inducer/arraycontext/issues/49
     actx = actx_factory()
 
diff --git a/test/test_utils.py b/test/test_utils.py
index a6aa271..807d652 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,7 +1,8 @@
 """Testing for internal  utilities."""
-from __future__ import annotations
 
-from typing import cast
+# NOTE: Do not add 'from __future__ import annotations' to this file. The
+# containers defined inside the tests below rely on their annotations being
+# actual types rather than strings.
 
 
 __copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees"
@@ -26,6 +27,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 import logging
+from typing import Optional, cast
 
 import numpy as np
 import pytest
@@ -55,25 +57,13 @@ def test_dataclass_array_container() -> None:
 
     from arraycontext import Array, dataclass_array_container
 
-    # {{{ string fields
-
-    @dataclass
-    class ArrayContainerWithStringTypes:
-        x: np.ndarray
-        y: np.ndarray
-
-    with pytest.raises(TypeError, match="String annotation on field 'y'"):
-        # NOTE: cannot have string annotations in container
-        dataclass_array_container(ArrayContainerWithStringTypes)
-
-    # }}}
-
     # {{{ optional fields
 
     @dataclass
     class ArrayContainerWithOptional:
         x: np.ndarray
-        y: np.ndarray | None
+        # Deliberately left as Optional to test compatibility.
+        y: Optional[np.ndarray]  # noqa: UP007
 
     with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
         # NOTE: cannot have wrapped annotations (here by `Optional`)
diff --git a/test/testlib.py b/test/testlib.py
new file mode 100644
index 0000000..3f08520
--- /dev/null
+++ b/test/testlib.py
@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+
+__copyright__ = "Copyright (C) 2020-21 University of Illinois Board of Trustees"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+from dataclasses import dataclass
+
+import numpy as np
+
+from arraycontext import (
+    ArrayContainer,
+    ArrayContext,
+    dataclass_array_container,
+    deserialize_container,
+    serialize_container,
+    with_array_context,
+    with_container_arithmetic,
+)
+
+
+# Containers live here, because in order for get_annotations to work, they must
+# live somewhere importable.
+# See https://docs.python.org/3.12/library/inspect.html#inspect.get_annotations
+
+
+# {{{ stand-in DOFArray implementation
+
+@with_container_arithmetic(
+        bcasts_across_obj_array=True,
+        bitwise=True,
+        rel_comparison=True,
+        _cls_has_array_context_attr=True,
+        _bcast_actx_array_type=False)
+class DOFArray:
+    def __init__(self, actx, data):
+        if not (actx is None or isinstance(actx, ArrayContext)):
+            raise TypeError("actx must be of type ArrayContext")
+
+        if not isinstance(data, tuple):
+            raise TypeError("'data' argument must be a tuple")
+
+        self.array_context = actx
+        self.data = data
+
+    # prevent numpy broadcasting
+    __array_ufunc__ = None
+
+    def __bool__(self):
+        if len(self) == 1 and self.data[0].size == 1:
+            return bool(self.data[0])
+
+        raise ValueError(
+                "The truth value of an array with more than one element is "
+                "ambiguous. Use actx.np.any(x) or actx.np.all(x)")
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, i):
+        return self.data[i]
+
+    def __repr__(self):
+        return f"DOFArray({self.data!r})"
+
+    @classmethod
+    def _serialize_init_arrays_code(cls, instance_name):
+        return {"_":
+                (f"{instance_name}_i", f"{instance_name}")}
+
+    @classmethod
+    def _deserialize_init_arrays_code(cls, template_instance_name, args):
+        (_, arg), = args.items()
+        # Why tuple([...])? https://stackoverflow.com/a/48592299
+        return (f"{template_instance_name}.array_context, tuple([{arg}])")
+
+    @property
+    def size(self):
+        return sum(ary.size for ary in self.data)
+
+    @property
+    def real(self):
+        return DOFArray(self.array_context, tuple(subary.real for subary in self))
+
+    @property
+    def imag(self):
+        return DOFArray(self.array_context, tuple(subary.imag for subary in self))
+
+
+@serialize_container.register(DOFArray)
+def _serialize_dof_container(ary: DOFArray):
+    return list(enumerate(ary.data))
+
+
+@deserialize_container.register(DOFArray)
+# https://github.com/python/mypy/issues/13040
+def _deserialize_dof_container(  # type: ignore[misc]
+        template, iterable):
+    def _raise_index_inconsistency(i, stream_i):
+        raise ValueError(
+                "out-of-sequence indices supplied in DOFArray deserialization "
+                f"(expected {i}, received {stream_i})")
+
+    return type(template)(
+            template.array_context,
+            data=tuple(
+                v if i == stream_i else _raise_index_inconsistency(i, stream_i)
+                for i, (stream_i, v) in enumerate(iterable)))
+
+
+@with_array_context.register(DOFArray)
+# https://github.com/python/mypy/issues/13040
+def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray:  # type: ignore[misc]
+    return type(ary)(actx, ary.data)
+
+# }}}
+
+
+# {{{ nested containers
+
+@with_container_arithmetic(bcasts_across_obj_array=False,
+        eq_comparison=False, rel_comparison=False,
+        _cls_has_array_context_attr=True,
+        _bcast_actx_array_type=False)
+@dataclass_array_container
+@dataclass(frozen=True)
+class MyContainer:
+    name: str
+    mass: DOFArray | np.ndarray
+    momentum: np.ndarray
+    enthalpy: DOFArray | np.ndarray
+
+    __array_ufunc__ = None
+
+    @property
+    def array_context(self):
+        if isinstance(self.mass, np.ndarray):
+            return next(iter(self.mass)).array_context
+        else:
+            return self.mass.array_context
+
+
+@with_container_arithmetic(
+        bcasts_across_obj_array=False,
+        bcast_container_types=(DOFArray, np.ndarray),
+        matmul=True,
+        rel_comparison=True,
+        _cls_has_array_context_attr=True,
+        _bcast_actx_array_type=False)
+@dataclass_array_container
+@dataclass(frozen=True)
+class MyContainerDOFBcast:
+    name: str
+    mass: DOFArray | np.ndarray
+    momentum: np.ndarray
+    enthalpy: DOFArray | np.ndarray
+
+    @property
+    def array_context(self):
+        if isinstance(self.mass, np.ndarray):
+            return next(iter(self.mass)).array_context
+        else:
+            return self.mass.array_context
+
+# }}}
+
+
+@with_container_arithmetic(
+    bcasts_across_obj_array=True,
+    rel_comparison=True,
+    _cls_has_array_context_attr=True,
+    _bcast_actx_array_type=False)
+@dataclass_array_container
+@dataclass(frozen=True)
+class Foo:
+    u: DOFArray
+
+    # prevent numpy arithmetic from taking precedence
+    __array_ufunc__ = None
+
+    @property
+    def array_context(self):
+        return self.u.array_context
+
+
+@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
+@dataclass_array_container
+@dataclass(frozen=True)
+class Velocity2D:
+    u: ArrayContainer
+    v: ArrayContainer
+    array_context: ArrayContext
+
+
+@with_array_context.register(Velocity2D)
+# https://github.com/python/mypy/issues/13040
+def _with_actx_velocity_2d(ary, actx):  # type: ignore[misc]
+    return type(ary)(ary.u, ary.v, actx)
-- 
GitLab