From 8559234ed64fa394f5851bbba47b008002772fd3 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 25 Oct 2021 19:54:33 -0500
Subject: [PATCH] better error message in dataclass_array_container

---
 arraycontext/container/__init__.py  |  4 ++
 arraycontext/container/dataclass.py | 29 ++++++++++++--
 test/test_utils.py                  | 59 ++++++++++++++++++++++++++++-
 3 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 72bd024..ca13935 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -173,6 +173,10 @@ def is_array_container_type(cls: type) -> bool:
         function will say that :class:`numpy.ndarray` is an array container
         type, only object arrays *actually are* array containers.
     """
+    assert isinstance(cls, type), \
+            f"must pass a type, not an instance: '{cls!r}'"
+    assert hasattr(cls, "__mro__"), "'cls' has no attribute '__mro__': "
+
     return (
             cls is ArrayContainer
             or (serialize_container.dispatch(cls)
diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index 2891f60..bf433f2 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -49,10 +49,31 @@ def dataclass_array_container(cls: type) -> type:
     from dataclasses import is_dataclass
     assert is_dataclass(cls)
 
-    array_fields = [
-            f for f in fields(cls) if is_array_container_type(f.type)]
-    non_array_fields = [
-            f for f in fields(cls) if not is_array_container_type(f.type)]
+    def is_array_field(f):
+        if __debug__:
+            if not f.init:
+                raise ValueError(
+                        f"'init=False' field not allowed: '{f.name}'")
+
+            if isinstance(f.type, str):
+                raise TypeError(
+                        f"string annotation on field '{f.name}' not supported")
+
+            from typing import _SpecialForm
+            if isinstance(f.type, _SpecialForm):
+                raise TypeError(
+                        f"typing annotation not supported on field '{f.name}': "
+                        f"'{f.type!r}'")
+
+            if not isinstance(f.type, type):
+                raise TypeError(
+                        f"field '{f.name}' not an instance of 'type': "
+                        f"'{f.type!r}'")
+
+        return is_array_container_type(f.type)
+
+    from pytools import partition
+    array_fields, non_array_fields = partition(is_array_field, fields(cls))
 
     if not array_fields:
         raise ValueError(f"'{cls}' must have fields with array container type "
diff --git a/test/test_utils.py b/test/test_utils.py
index 2228152..08b6c3a 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -22,11 +22,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
+import pytest
+
+import numpy as np
 
 import logging
 logger = logging.getLogger(__name__)
 
 
+# {{{ test_pt_actx_key_stringification_uniqueness
+
 def test_pt_actx_key_stringification_uniqueness():
     from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
 
@@ -36,13 +41,63 @@ def test_pt_actx_key_stringification_uniqueness():
     assert (_ary_container_key_stringifier(("tup", 3, "endtup"))
             != _ary_container_key_stringifier(((3,),)))
 
+# }}}
+
+
+# {{{ test_dataclass_array_container
+
+def test_dataclass_array_container():
+    from typing import Optional
+    from dataclasses import dataclass, field
+    from arraycontext import dataclass_array_container
+
+    # {{{ string fields
+
+    @dataclass
+    class ArrayContainerWithStringTypes:
+        x: np.ndarray
+        y: "np.ndarray"
+
+    with pytest.raises(TypeError):
+        # NOTE: cannot have string annotations in container
+        dataclass_array_container(ArrayContainerWithStringTypes)
+
+    # }}}
+
+    # {{{ optional fields
+
+    @dataclass
+    class ArrayContainerWithOptional:
+        x: np.ndarray
+        y: Optional[np.ndarray]
+
+    with pytest.raises(TypeError):
+        # NOTE: cannot have wrapped annotations (here by `Optional`)
+        dataclass_array_container(ArrayContainerWithOptional)
+
+    # }}}
+
+    # {{{ field(init=False)
+
+    @dataclass
+    class ArrayContainerWithInitFalse:
+        x: np.ndarray
+        y: np.ndarray = field(default=np.zeros(42), init=False, repr=False)
+
+    with pytest.raises(ValueError):
+        # NOTE: init=False fields are not allowed
+        dataclass_array_container(ArrayContainerWithInitFalse)
+
+    # }}}
+
+# }}}
+
 
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
         exec(sys.argv[1])
     else:
-        from pytest import main
-        main([__file__])
+        pytest.main([__file__])
 
 # vim: fdm=marker
-- 
GitLab