diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 150d1d6838aae19c369126be75e70216e6f228a3..edbb45061c230d7d34d0e88d95896dd40ef8d6f4 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -30,6 +30,13 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import Union, get_args +try: + # NOTE: only available in python >= 3.8 + from typing import get_origin +except ImportError: + from typing_extensions import get_origin + from dataclasses import fields from arraycontext.container import is_array_container_type @@ -50,6 +57,19 @@ def dataclass_array_container(cls: type) -> type: assert is_dataclass(cls) def is_array_field(f: Field) -> bool: + from arraycontext import Array + + origin = get_origin(f.type) + if origin is Union: + if not all( + arg is Array or is_array_container_type(arg) + for arg in get_args(f.type)): + raise TypeError( + f"Field '{f.name}' union contains non-array container " + "arguments. All arguments must be array containers.") + else: + return True + if __debug__: if not f.init: raise ValueError( @@ -61,6 +81,7 @@ def dataclass_array_container(cls: type) -> type: from typing import _SpecialForm if isinstance(f.type, _SpecialForm): + # NOTE: anything except a Union is not allowed raise TypeError( f"typing annotation not supported on field '{f.name}': " f"'{f.type!r}'") @@ -70,7 +91,6 @@ def dataclass_array_container(cls: type) -> type: f"field '{f.name}' not an instance of 'type': " f"'{f.type!r}'") - from arraycontext import Array return f.type is Array or is_array_container_type(f.type) from pytools import partition diff --git a/setup.py b/setup.py index 2bc066ec81fb564ed1093f98231bae87973a3243..eb6421c29d28d2d00ce7e5e259285d178a1f12c3 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ def main(): "pytest>=2.3", "loopy>=2019.1", "dataclasses; python_version<'3.7'", - "typing_extensions; python_version<'3.8'", + "typing_extensions; python_version<'3.9'", "types-dataclasses", ], package_data={"arraycontext": ["py.typed"]}, diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index cd61120f0c7a8ca04347d3595ade5d8bf0dd5617..acf099716829400fd079fdf1dc0913c138a8dcd5 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -21,6 +21,8 @@ THE SOFTWARE. """ from dataclasses import dataclass +from typing import Union + import numpy as np import pytest @@ -678,9 +680,9 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec): @dataclass(frozen=True) class MyContainer: name: str - mass: DOFArray # or np.ndarray + mass: Union[DOFArray, np.ndarray] momentum: np.ndarray - enthalpy: DOFArray # or np.ndarray + enthalpy: Union[DOFArray, np.ndarray] @property def array_context(self): @@ -700,9 +702,9 @@ class MyContainer: @dataclass(frozen=True) class MyContainerDOFBcast: name: str - mass: DOFArray # or np.ndarray + mass: Union[DOFArray, np.ndarray] momentum: np.ndarray - enthalpy: DOFArray # or np.ndarray + enthalpy: Union[DOFArray, np.ndarray] @property def array_context(self): diff --git a/test/test_utils.py b/test/test_utils.py index ac3127fe8ce7843f5a7c513eb30baa6688e6e73a..7a12ad273c87469809d2d1c6d8626dfd434b8c77 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -106,6 +106,42 @@ def test_dataclass_array_container(): # }}} +# {{{ test_dataclass_container_unions + +def test_dataclass_container_unions(): + from dataclasses import dataclass + from arraycontext import dataclass_array_container + + from typing import Union + from arraycontext import Array + + # {{{ union fields + + @dataclass + class ArrayContainerWithUnion: + x: np.ndarray + y: Union[np.ndarray, Array] + + dataclass_array_container(ArrayContainerWithUnion) + + # }}} + + # {{{ non-container union + + @dataclass + class ArrayContainerWithWrongUnion: + x: np.ndarray + y: Union[np.ndarray, float] + + with pytest.raises(TypeError): + # NOTE: float is not an ArrayContainer, so y should fail + dataclass_array_container(ArrayContainerWithWrongUnion) + + # }}} + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: