Skip to content
Snippets Groups Projects
Commit 9ac8bc88 authored by Alexandru Fikl's avatar Alexandru Fikl Committed by Andreas Klöckner
Browse files

dataclass: refactor evaluating string fields

parent c4f00b8b
No related branches found
No related tags found
No related merge requests found
Pipeline #628993 failed
......@@ -32,14 +32,22 @@ THE SOFTWARE.
"""
from collections.abc import Mapping, Sequence
from dataclasses import Field, fields, is_dataclass
from typing import Union, get_args, get_origin
from dataclasses import fields, is_dataclass
from typing import NamedTuple, Union, get_args, get_origin
from arraycontext.container import is_array_container_type
# {{{ dataclass containers
class _Field(NamedTuple):
"""Small lookalike for :class:`dataclasses.Field`."""
init: bool
name: str
type: type
def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)
......@@ -73,7 +81,9 @@ def dataclass_array_container(cls: type) -> type:
assert is_dataclass(cls)
def is_array_field(f: Field, field_type: type) -> bool:
def is_array_field(f: _Field) -> bool:
field_type = f.type
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
......@@ -96,10 +106,8 @@ def dataclass_array_container(cls: type) -> type:
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
if isinstance(field_type, str):
raise TypeError(
f"String annotation on field '{f.name}' not supported. "
"(this may be due to 'from __future__ import annotations')")
# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)
if __debug__:
if not f.init:
......@@ -127,36 +135,52 @@ def dataclass_array_container(cls: type) -> type:
return is_array_type(field_type)
from pytools import partition
array_fields = _get_annotated_fields(cls)
array_fields, non_array_fields = partition(is_array_field, array_fields)
if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
def _get_annotated_fields(cls: type) -> Sequence[_Field]:
"""Get a list of fields in the class *cls* with evaluated types.
If any of the fields in *cls* have type annotations that are strings, e.g.
from using ``from __future__ import annotations``, this function evaluates
them using :func:`inspect.get_annotations`. Note that this requires the class
to live in a module that is importable.
:return: a list of fields.
"""
from inspect import get_annotations
array_fields: list[Field] = []
non_array_fields: list[Field] = []
result = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)
field_type = cls_ann[field.name]
else:
field_type = field_type_or_str
if is_array_field(field, field_type):
array_fields.append(field)
else:
non_array_fields.append(field)
if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")
result.append(_Field(init=field.init, name=field.name, type=field_type))
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
return result
def _inject_dataclass_serialization(
cls: type,
array_fields: Sequence[Field],
non_array_fields: Sequence[Field]) -> type:
array_fields: Sequence[_Field],
non_array_fields: Sequence[_Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment