From 1ea77f315bf4ab5bccd920dd387164c1d224f2e0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 27 Nov 2024 12:59:36 -0600 Subject: [PATCH] dataclass_array_container: support string annotations --- arraycontext/container/dataclass.py | 55 +++++-- pyproject.toml | 4 + test/test_arraycontext.py | 182 +---------------------- test/test_utils.py | 22 +-- test/testlib.py | 216 ++++++++++++++++++++++++++++ 5 files changed, 270 insertions(+), 209 deletions(-) create mode 100644 test/testlib.py diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 9495b34..5ff9dfd 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -31,6 +31,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from collections.abc import Mapping, Sequence from dataclasses import Field, fields, is_dataclass from typing import Union, get_args, get_origin @@ -58,13 +59,21 @@ def dataclass_array_container(cls: type) -> type: * a :class:`typing.Union` of array containers is considered an array container. * other type annotations, e.g. :class:`typing.Optional`, are not considered array containers, even if they wrap one. + + .. note:: + + When type annotations are strings (e.g. because of + ``from __future__ import annotations``), + this function relies on :func:`inspect.get_annotations` + (with ``eval_str=True``) to obtain type annotations. This + means that *cls* must live in a module that is importable. """ from types import GenericAlias, UnionType assert is_dataclass(cls) - def is_array_field(f: Field) -> bool: + def is_array_field(f: Field, field_type: type) -> bool: # NOTE: unions of array containers are treated separately to handle # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as # they can work seamlessly with arithmetic and traversal. @@ -77,17 +86,17 @@ def dataclass_array_container(cls: type) -> type: # # This is not set in stone, but mostly driven by current usage! - origin = get_origin(f.type) + origin = get_origin(field_type) # NOTE: `UnionType` is returned when using `Type1 | Type2` if origin in (Union, UnionType): - if all(is_array_type(arg) for arg in get_args(f.type)): + if all(is_array_type(arg) for arg in get_args(field_type)): return True else: raise TypeError( f"Field '{f.name}' union contains non-array container " "arguments. All arguments must be array containers.") - if isinstance(f.type, str): + if isinstance(field_type, str): raise TypeError( f"String annotation on field '{f.name}' not supported. " "(this may be due to 'from __future__ import annotations')") @@ -105,33 +114,49 @@ def dataclass_array_container(cls: type) -> type: _BaseGenericAlias, _SpecialForm, ) - if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm): + if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm): # NOTE: anything except a Union is not allowed raise TypeError( f"Typing annotation not supported on field '{f.name}': " - f"'{f.type!r}'") + f"'{field_type!r}'") - if not isinstance(f.type, type): + if not isinstance(field_type, type): raise TypeError( f"Field '{f.name}' not an instance of 'type': " - f"'{f.type!r}'") + f"'{field_type!r}'") + + return is_array_type(field_type) + + from inspect import get_annotations - return is_array_type(f.type) + array_fields: list[Field] = [] + non_array_fields: list[Field] = [] + cls_ann: Mapping[str, type] | None = None + for field in fields(cls): + field_type_or_str = field.type + if isinstance(field_type_or_str, str): + if cls_ann is None: + cls_ann = get_annotations(cls, eval_str=True) + field_type = cls_ann[field.name] + else: + field_type = field_type_or_str - from pytools import partition - array_fields, non_array_fields = partition(is_array_field, fields(cls)) + if is_array_field(field, field_type): + array_fields.append(field) + else: + non_array_fields.append(field) if not array_fields: raise ValueError(f"'{cls}' must have fields with array container type " "in order to use the 'dataclass_array_container' decorator") - return inject_dataclass_serialization(cls, array_fields, non_array_fields) + return _inject_dataclass_serialization(cls, array_fields, non_array_fields) -def inject_dataclass_serialization( +def _inject_dataclass_serialization( cls: type, - array_fields: tuple[Field, ...], - non_array_fields: tuple[Field, ...]) -> type: + array_fields: Sequence[Field], + non_array_fields: Sequence[Field]) -> type: """Implements :func:`~arraycontext.serialize_container` and :func:`~arraycontext.deserialize_container` for the given dataclass *cls*. diff --git a/pyproject.toml b/pyproject.toml index a9c1df4..d715981 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,12 +97,16 @@ known-first-party = [ ] known-local-folder = [ "arraycontext", + "testlib", ] lines-after-imports = 2 required-imports = ["from __future__ import annotations"] [tool.ruff.lint.per-file-ignores] "doc/conf.py" = ["I002"] +# To avoid a requirement of array container definitions being someplace importable +# from @dataclass_array_container. +"test/test_utils.py" = ["I002"] [tool.mypy] python_version = "3.10" diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 050bfc8..ab26330 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -34,18 +34,14 @@ from pytools.obj_array import make_obj_array from pytools.tag import Tag from arraycontext import ( - ArrayContainer, - ArrayContext, EagerJAXArrayContext, NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, - deserialize_container, pytest_generate_tests_for_array_contexts, serialize_container, tag_axes, - with_array_context, with_container_arithmetic, ) from arraycontext.pytest import ( @@ -55,6 +51,7 @@ from arraycontext.pytest import ( _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, ) +from testlib import DOFArray, MyContainer, MyContainerDOFBcast, Velocity2D logger = logging.getLogger(__name__) @@ -116,147 +113,10 @@ def _acf(): # }}} -# {{{ stand-in DOFArray implementation - -@with_container_arithmetic( - bcasts_across_obj_array=True, - bitwise=True, - rel_comparison=True, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -class DOFArray: - def __init__(self, actx, data): - if not (actx is None or isinstance(actx, ArrayContext)): - raise TypeError("actx must be of type ArrayContext") - - if not isinstance(data, tuple): - raise TypeError("'data' argument must be a tuple") - - self.array_context = actx - self.data = data - - # prevent numpy broadcasting - __array_ufunc__ = None - - def __bool__(self): - if len(self) == 1 and self.data[0].size == 1: - return bool(self.data[0]) - - raise ValueError( - "The truth value of an array with more than one element is " - "ambiguous. Use actx.np.any(x) or actx.np.all(x)") - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - return self.data[i] - - def __repr__(self): - return f"DOFArray({self.data!r})" - - @classmethod - def _serialize_init_arrays_code(cls, instance_name): - return {"_": - (f"{instance_name}_i", f"{instance_name}")} - - @classmethod - def _deserialize_init_arrays_code(cls, template_instance_name, args): - (_, arg), = args.items() - # Why tuple([...])? https://stackoverflow.com/a/48592299 - return (f"{template_instance_name}.array_context, tuple([{arg}])") - - @property - def size(self): - return sum(ary.size for ary in self.data) - - @property - def real(self): - return DOFArray(self.array_context, tuple(subary.real for subary in self)) - - @property - def imag(self): - return DOFArray(self.array_context, tuple(subary.imag for subary in self)) - - -@serialize_container.register(DOFArray) -def _serialize_dof_container(ary: DOFArray): - return list(enumerate(ary.data)) - - -@deserialize_container.register(DOFArray) -# https://github.com/python/mypy/issues/13040 -def _deserialize_dof_container( # type: ignore[misc] - template, iterable): - def _raise_index_inconsistency(i, stream_i): - raise ValueError( - "out-of-sequence indices supplied in DOFArray deserialization " - f"(expected {i}, received {stream_i})") - - return type(template)( - template.array_context, - data=tuple( - v if i == stream_i else _raise_index_inconsistency(i, stream_i) - for i, (stream_i, v) in enumerate(iterable))) - - -@with_array_context.register(DOFArray) -# https://github.com/python/mypy/issues/13040 -def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: ignore[misc] - return type(ary)(actx, ary.data) - -# }}} - - -# {{{ nested containers - -@with_container_arithmetic(bcasts_across_obj_array=False, - eq_comparison=False, rel_comparison=False, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -@dataclass_array_container -@dataclass(frozen=True) -class MyContainer: - name: str - mass: DOFArray | np.ndarray - momentum: np.ndarray - enthalpy: DOFArray | np.ndarray - - __array_ufunc__ = None - - @property - def array_context(self): - if isinstance(self.mass, np.ndarray): - return next(iter(self.mass)).array_context - else: - return self.mass.array_context - - -@with_container_arithmetic( - bcasts_across_obj_array=False, - bcast_container_types=(DOFArray, np.ndarray), - matmul=True, - rel_comparison=True, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -@dataclass_array_container -@dataclass(frozen=True) -class MyContainerDOFBcast: - name: str - mass: DOFArray | np.ndarray - momentum: np.ndarray - enthalpy: DOFArray | np.ndarray - - @property - def array_context(self): - if isinstance(self.mass, np.ndarray): - return next(iter(self.mass)).array_context - else: - return self.mass.array_context - - def _get_test_containers(actx, ambient_dim=2, shapes=50_000): from numbers import Number + + from testlib import DOFArray, MyContainer, MyContainerDOFBcast if isinstance(shapes, Number | tuple): shapes = [shapes] @@ -286,8 +146,6 @@ def _get_test_containers(actx, ambient_dim=2, shapes=50_000): return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs, bcast_dataclass_of_dofs) -# }}} - # {{{ assert_close_to_numpy* @@ -1224,21 +1082,6 @@ def test_norm_ord_none(actx_factory, ndim): # {{{ test_actx_compile helpers -@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) -@dataclass_array_container -@dataclass(frozen=True) -class Velocity2D: - u: ArrayContainer - v: ArrayContainer - array_context: ArrayContext - - -@with_array_context.register(Velocity2D) -# https://github.com/python/mypy/issues/13040 -def _with_actx_velocity_2d(ary, actx): # type: ignore[misc] - return type(ary)(ary.u, ary.v, actx) - - def scale_and_orthogonalize(alpha, vel): from arraycontext import rec_map_array_container actx = vel.array_context @@ -1353,25 +1196,8 @@ def test_container_equality(actx_factory): # {{{ test_no_leaf_array_type_broadcasting -@with_container_arithmetic( - bcasts_across_obj_array=True, - rel_comparison=True, - _cls_has_array_context_attr=True, - _bcast_actx_array_type=False) -@dataclass_array_container -@dataclass(frozen=True) -class Foo: - u: DOFArray - - # prevent numpy arithmetic from taking precedence - __array_ufunc__ = None - - @property - def array_context(self): - return self.u.array_context - - def test_no_leaf_array_type_broadcasting(actx_factory): + from testlib import Foo # test lack of support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() diff --git a/test/test_utils.py b/test/test_utils.py index a6aa271..807d652 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,7 +1,8 @@ """Testing for internal utilities.""" -from __future__ import annotations -from typing import cast +# Do not add +# from __future__ import annotations +# to allow the non-string annotations below to work. __copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees" @@ -26,6 +27,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import logging +from typing import Optional, cast import numpy as np import pytest @@ -55,25 +57,13 @@ def test_dataclass_array_container() -> None: from arraycontext import Array, dataclass_array_container - # {{{ string fields - - @dataclass - class ArrayContainerWithStringTypes: - x: np.ndarray - y: np.ndarray - - with pytest.raises(TypeError, match="String annotation on field 'y'"): - # NOTE: cannot have string annotations in container - dataclass_array_container(ArrayContainerWithStringTypes) - - # }}} - # {{{ optional fields @dataclass class ArrayContainerWithOptional: x: np.ndarray - y: np.ndarray | None + # Deliberately left as Optional to test compatibility. + y: Optional[np.ndarray] # noqa: UP007 with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`) diff --git a/test/testlib.py b/test/testlib.py new file mode 100644 index 0000000..3f08520 --- /dev/null +++ b/test/testlib.py @@ -0,0 +1,216 @@ +from __future__ import annotations + + +__copyright__ = "Copyright (C) 2020-21 University of Illinois Board of Trustees" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from dataclasses import dataclass + +import numpy as np + +from arraycontext import ( + ArrayContainer, + ArrayContext, + dataclass_array_container, + deserialize_container, + serialize_container, + with_array_context, + with_container_arithmetic, +) + + +# Containers live here, because in order for get_annotations to work, they must +# live somewhere importable. +# See https://docs.python.org/3.12/library/inspect.html#inspect.get_annotations + + +# {{{ stand-in DOFArray implementation + +@with_container_arithmetic( + bcasts_across_obj_array=True, + bitwise=True, + rel_comparison=True, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +class DOFArray: + def __init__(self, actx, data): + if not (actx is None or isinstance(actx, ArrayContext)): + raise TypeError("actx must be of type ArrayContext") + + if not isinstance(data, tuple): + raise TypeError("'data' argument must be a tuple") + + self.array_context = actx + self.data = data + + # prevent numpy broadcasting + __array_ufunc__ = None + + def __bool__(self): + if len(self) == 1 and self.data[0].size == 1: + return bool(self.data[0]) + + raise ValueError( + "The truth value of an array with more than one element is " + "ambiguous. Use actx.np.any(x) or actx.np.all(x)") + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def __repr__(self): + return f"DOFArray({self.data!r})" + + @classmethod + def _serialize_init_arrays_code(cls, instance_name): + return {"_": + (f"{instance_name}_i", f"{instance_name}")} + + @classmethod + def _deserialize_init_arrays_code(cls, template_instance_name, args): + (_, arg), = args.items() + # Why tuple([...])? https://stackoverflow.com/a/48592299 + return (f"{template_instance_name}.array_context, tuple([{arg}])") + + @property + def size(self): + return sum(ary.size for ary in self.data) + + @property + def real(self): + return DOFArray(self.array_context, tuple(subary.real for subary in self)) + + @property + def imag(self): + return DOFArray(self.array_context, tuple(subary.imag for subary in self)) + + +@serialize_container.register(DOFArray) +def _serialize_dof_container(ary: DOFArray): + return list(enumerate(ary.data)) + + +@deserialize_container.register(DOFArray) +# https://github.com/python/mypy/issues/13040 +def _deserialize_dof_container( # type: ignore[misc] + template, iterable): + def _raise_index_inconsistency(i, stream_i): + raise ValueError( + "out-of-sequence indices supplied in DOFArray deserialization " + f"(expected {i}, received {stream_i})") + + return type(template)( + template.array_context, + data=tuple( + v if i == stream_i else _raise_index_inconsistency(i, stream_i) + for i, (stream_i, v) in enumerate(iterable))) + + +@with_array_context.register(DOFArray) +# https://github.com/python/mypy/issues/13040 +def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: ignore[misc] + return type(ary)(actx, ary.data) + +# }}} + + +# {{{ nested containers + +@with_container_arithmetic(bcasts_across_obj_array=False, + eq_comparison=False, rel_comparison=False, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +@dataclass_array_container +@dataclass(frozen=True) +class MyContainer: + name: str + mass: DOFArray | np.ndarray + momentum: np.ndarray + enthalpy: DOFArray | np.ndarray + + __array_ufunc__ = None + + @property + def array_context(self): + if isinstance(self.mass, np.ndarray): + return next(iter(self.mass)).array_context + else: + return self.mass.array_context + + +@with_container_arithmetic( + bcasts_across_obj_array=False, + bcast_container_types=(DOFArray, np.ndarray), + matmul=True, + rel_comparison=True, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +@dataclass_array_container +@dataclass(frozen=True) +class MyContainerDOFBcast: + name: str + mass: DOFArray | np.ndarray + momentum: np.ndarray + enthalpy: DOFArray | np.ndarray + + @property + def array_context(self): + if isinstance(self.mass, np.ndarray): + return next(iter(self.mass)).array_context + else: + return self.mass.array_context + +# }}} + + +@with_container_arithmetic( + bcasts_across_obj_array=True, + rel_comparison=True, + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) +@dataclass_array_container +@dataclass(frozen=True) +class Foo: + u: DOFArray + + # prevent numpy arithmetic from taking precedence + __array_ufunc__ = None + + @property + def array_context(self): + return self.u.array_context + + +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) +@dataclass_array_container +@dataclass(frozen=True) +class Velocity2D: + u: ArrayContainer + v: ArrayContainer + array_context: ArrayContext + + +@with_array_context.register(Velocity2D) +# https://github.com/python/mypy/issues/13040 +def _with_actx_velocity_2d(ary, actx): # type: ignore[misc] + return type(ary)(ary.u, ary.v, actx) -- GitLab