From 8559234ed64fa394f5851bbba47b008002772fd3 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Mon, 25 Oct 2021 19:54:33 -0500 Subject: [PATCH] better error message in dataclass_array_container --- arraycontext/container/__init__.py | 4 ++ arraycontext/container/dataclass.py | 29 ++++++++++++-- test/test_utils.py | 59 ++++++++++++++++++++++++++++- 3 files changed, 86 insertions(+), 6 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 72bd024..ca13935 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -173,6 +173,10 @@ def is_array_container_type(cls: type) -> bool: function will say that :class:`numpy.ndarray` is an array container type, only object arrays *actually are* array containers. """ + assert isinstance(cls, type), \ + f"must pass a type, not an instance: '{cls!r}'" + assert hasattr(cls, "__mro__"), "'cls' has no attribute '__mro__': " + return ( cls is ArrayContainer or (serialize_container.dispatch(cls) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 2891f60..bf433f2 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -49,10 +49,31 @@ def dataclass_array_container(cls: type) -> type: from dataclasses import is_dataclass assert is_dataclass(cls) - array_fields = [ - f for f in fields(cls) if is_array_container_type(f.type)] - non_array_fields = [ - f for f in fields(cls) if not is_array_container_type(f.type)] + def is_array_field(f): + if __debug__: + if not f.init: + raise ValueError( + f"'init=False' field not allowed: '{f.name}'") + + if isinstance(f.type, str): + raise TypeError( + f"string annotation on field '{f.name}' not supported") + + from typing import _SpecialForm + if isinstance(f.type, _SpecialForm): + raise TypeError( + f"typing annotation not supported on field '{f.name}': " + f"'{f.type!r}'") + + if not isinstance(f.type, type): + raise TypeError( + f"field '{f.name}' not an instance of 'type': " + f"'{f.type!r}'") + + return is_array_container_type(f.type) + + from pytools import partition + array_fields, non_array_fields = partition(is_array_field, fields(cls)) if not array_fields: raise ValueError(f"'{cls}' must have fields with array container type " diff --git a/test/test_utils.py b/test/test_utils.py index 2228152..08b6c3a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -22,11 +22,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import pytest + +import numpy as np import logging logger = logging.getLogger(__name__) +# {{{ test_pt_actx_key_stringification_uniqueness + def test_pt_actx_key_stringification_uniqueness(): from arraycontext.impl.pytato.compile import _ary_container_key_stringifier @@ -36,13 +41,63 @@ def test_pt_actx_key_stringification_uniqueness(): assert (_ary_container_key_stringifier(("tup", 3, "endtup")) != _ary_container_key_stringifier(((3,),))) +# }}} + + +# {{{ test_dataclass_array_container + +def test_dataclass_array_container(): + from typing import Optional + from dataclasses import dataclass, field + from arraycontext import dataclass_array_container + + # {{{ string fields + + @dataclass + class ArrayContainerWithStringTypes: + x: np.ndarray + y: "np.ndarray" + + with pytest.raises(TypeError): + # NOTE: cannot have string annotations in container + dataclass_array_container(ArrayContainerWithStringTypes) + + # }}} + + # {{{ optional fields + + @dataclass + class ArrayContainerWithOptional: + x: np.ndarray + y: Optional[np.ndarray] + + with pytest.raises(TypeError): + # NOTE: cannot have wrapped annotations (here by `Optional`) + dataclass_array_container(ArrayContainerWithOptional) + + # }}} + + # {{{ field(init=False) + + @dataclass + class ArrayContainerWithInitFalse: + x: np.ndarray + y: np.ndarray = field(default=np.zeros(42), init=False, repr=False) + + with pytest.raises(ValueError): + # NOTE: init=False fields are not allowed + dataclass_array_container(ArrayContainerWithInitFalse) + + # }}} + +# }}} + if __name__ == "__main__": import sys if len(sys.argv) > 1: exec(sys.argv[1]) else: - from pytest import main - main([__file__]) + pytest.main([__file__]) # vim: fdm=marker -- GitLab