Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
A
arraycontext
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Andreas Klöckner
arraycontext
Commits
9ac8bc88
Commit
9ac8bc88
authored
3 weeks ago
by
Alexandru Fikl
Committed by
Andreas Klöckner
3 weeks ago
Browse files
Options
Downloads
Patches
Plain Diff
dataclass: refactor evaluating string fields
parent
c4f00b8b
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Pipeline
#628993
failed
4 days ago
Stage: test
Changes
1
Pipelines
4
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
arraycontext/container/dataclass.py
+44
-20
44 additions, 20 deletions
arraycontext/container/dataclass.py
with
44 additions
and
20 deletions
arraycontext/container/dataclass.py
+
44
−
20
View file @
9ac8bc88
...
...
@@ -32,14 +32,22 @@ THE SOFTWARE.
"""
from
collections.abc
import
Mapping
,
Sequence
from
dataclasses
import
Field
,
fields
,
is_dataclass
from
typing
import
Union
,
get_args
,
get_origin
from
dataclasses
import
fields
,
is_dataclass
from
typing
import
NamedTuple
,
Union
,
get_args
,
get_origin
from
arraycontext.container
import
is_array_container_type
# {{{ dataclass containers
class
_Field
(
NamedTuple
):
"""
Small lookalike for :class:`dataclasses.Field`.
"""
init
:
bool
name
:
str
type
:
type
def
is_array_type
(
tp
:
type
)
->
bool
:
from
arraycontext
import
Array
return
tp
is
Array
or
is_array_container_type
(
tp
)
...
...
@@ -73,7 +81,9 @@ def dataclass_array_container(cls: type) -> type:
assert
is_dataclass
(
cls
)
def
is_array_field
(
f
:
Field
,
field_type
:
type
)
->
bool
:
def
is_array_field
(
f
:
_Field
)
->
bool
:
field_type
=
f
.
type
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
...
...
@@ -96,10 +106,8 @@ def dataclass_array_container(cls: type) -> type:
f
"
Field
'
{
f
.
name
}
'
union contains non-array container
"
"
arguments. All arguments must be array containers.
"
)
if
isinstance
(
field_type
,
str
):
raise
TypeError
(
f
"
String annotation on field
'
{
f
.
name
}
'
not supported.
"
"
(this may be due to
'
from __future__ import annotations
'
)
"
)
# NOTE: this should never happen due to using `inspect.get_annotations`
assert
not
isinstance
(
field_type
,
str
)
if
__debug__
:
if
not
f
.
init
:
...
...
@@ -127,36 +135,52 @@ def dataclass_array_container(cls: type) -> type:
return
is_array_type
(
field_type
)
from
pytools
import
partition
array_fields
=
_get_annotated_fields
(
cls
)
array_fields
,
non_array_fields
=
partition
(
is_array_field
,
array_fields
)
if
not
array_fields
:
raise
ValueError
(
f
"'
{
cls
}
'
must have fields with array container type
"
"
in order to use the
'
dataclass_array_container
'
decorator
"
)
return
_inject_dataclass_serialization
(
cls
,
array_fields
,
non_array_fields
)
def
_get_annotated_fields
(
cls
:
type
)
->
Sequence
[
_Field
]:
"""
Get a list of fields in the class *cls* with evaluated types.
If any of the fields in *cls* have type annotations that are strings, e.g.
from using ``from __future__ import annotations``, this function evaluates
them using :func:`inspect.get_annotations`. Note that this requires the class
to live in a module that is importable.
:return: a list of fields.
"""
from
inspect
import
get_annotations
array_fields
:
list
[
Field
]
=
[]
non_array_fields
:
list
[
Field
]
=
[]
result
=
[]
cls_ann
:
Mapping
[
str
,
type
]
|
None
=
None
for
field
in
fields
(
cls
):
field_type_or_str
=
field
.
type
if
isinstance
(
field_type_or_str
,
str
):
if
cls_ann
is
None
:
cls_ann
=
get_annotations
(
cls
,
eval_str
=
True
)
field_type
=
cls_ann
[
field
.
name
]
else
:
field_type
=
field_type_or_str
if
is_array_field
(
field
,
field_type
):
array_fields
.
append
(
field
)
else
:
non_array_fields
.
append
(
field
)
if
not
array_fields
:
raise
ValueError
(
f
"'
{
cls
}
'
must have fields with array container type
"
"
in order to use the
'
dataclass_array_container
'
decorator
"
)
result
.
append
(
_Field
(
init
=
field
.
init
,
name
=
field
.
name
,
type
=
field_type
))
return
_inject_dataclass_serialization
(
cls
,
array_fields
,
non_array_fields
)
return
result
def
_inject_dataclass_serialization
(
cls
:
type
,
array_fields
:
Sequence
[
Field
],
non_array_fields
:
Sequence
[
Field
])
->
type
:
array_fields
:
Sequence
[
_
Field
],
non_array_fields
:
Sequence
[
_
Field
])
->
type
:
"""
Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment