diff --git a/pytools/__init__.py b/pytools/__init__.py index 0f452a804fb8103fb91309c370172b7aa422c3fa..33d5b206a68f1a009ad72a298515feb2b9634d5c 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -37,7 +37,7 @@ from functools import reduce, wraps from sys import intern from typing import ( Any, Callable, ClassVar, Dict, Generic, Hashable, Iterable, Iterator, List, - Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) + Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, cast) try: @@ -411,9 +411,16 @@ class RecordWithoutPickling: """ __slots__: ClassVar[List[str]] = [] - fields: ClassVar[Set[str]] - def __init__(self, valuedict=None, exclude=None, **kwargs): + # A dict, not a set, to maintain a deterministic iteration order + fields: ClassVar[Dict[str, None]] + + def __init__(self, valuedict: Optional[Mapping[str, Any]] = None, + exclude: Optional[Iterable[str]] = None, **kwargs: Any) -> None: + from warnings import warn + warn(f"{self.__class__.__bases__[0]} is deprecated and will be " + "removed in 2025. Use dataclasses instead.") + assert self.__class__ is not Record if exclude is None: @@ -422,17 +429,20 @@ class RecordWithoutPickling: try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} + + if isinstance(fields, set): + self.__class__.fields = fields = dict.fromkeys(sorted(fields)) if valuedict is not None: kwargs.update(valuedict) for key, value in kwargs.items(): if key not in exclude: - fields.add(key) + fields[key] = None setattr(self, key, value) - def get_copy_kwargs(self, **kwargs): + def get_copy_kwargs(self, **kwargs: Any) -> Dict[str, Any]: for f in self.__class__.fields: if f not in kwargs: try: @@ -441,25 +451,25 @@ class RecordWithoutPickling: pass return kwargs - def copy(self, **kwargs): + def copy(self, **kwargs: Any) -> "RecordWithoutPickling": return self.__class__(**self.get_copy_kwargs(**kwargs)) - def __repr__(self): + def __repr__(self) -> str: return "{}({})".format( self.__class__.__name__, ", ".join(f"{fld}={getattr(self, fld)!r}" for fld in self.__class__.fields if hasattr(self, fld))) - def register_fields(self, new_fields): + def register_fields(self, new_fields: Iterable[str]) -> None: try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} - fields.update(new_fields) + fields.update(dict.fromkeys(sorted(new_fields))) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # This method is implemented to avoid pylint 'no-member' errors for # attribute access. raise AttributeError( @@ -470,46 +480,46 @@ class RecordWithoutPickling: class Record(RecordWithoutPickling): __slots__: ClassVar[List[str]] = [] - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { key: getattr(self, key) for key in self.__class__.fields if hasattr(self, key)} - def __setstate__(self, valuedict): + def __setstate__(self, valuedict: Mapping[str, Any]) -> None: try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} + + if isinstance(fields, set): + self.__class__.fields = fields = dict.fromkeys(sorted(fields)) for key, value in valuedict.items(): - fields.add(key) + fields[key] = None setattr(self, key, value) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self is other: return True return (self.__class__ == other.__class__ and self.__getstate__() == other.__getstate__()) - def __ne__(self, other): - return not self.__eq__(other) - class ImmutableRecordWithoutPickling(RecordWithoutPickling): """Hashable record. Does not explicitly enforce immutability.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: RecordWithoutPickling.__init__(self, *args, **kwargs) - self._cached_hash = None + self._cached_hash: Optional[int] = None - def __hash__(self): + def __hash__(self) -> int: # This attribute may vanish during pickling. if getattr(self, "_cached_hash", None) is None: self._cached_hash = hash( (type(self),) + tuple(getattr(self, field) for field in self.__class__.fields)) - return self._cached_hash + return cast(int, self._cached_hash) class ImmutableRecord(ImmutableRecordWithoutPickling, Record): diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index 7e1fe3ace555aed573417bc04b596773b691acd1..294c3eb0610c1b061ce258aa02fffc0cbcf77b30 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -26,6 +26,8 @@ import sys import pytest +from pytools import Record + logger = logging.getLogger(__name__) from typing import FrozenSet @@ -784,6 +786,139 @@ def test_unique(): assert next(unique([]), None) is None +# These classes must be defined globally to be picklable +class SimpleRecord(Record): + pass + + +class SetBasedRecord(Record): + fields = {"c", "b", "a"} # type: ignore[assignment] + + def __init__(self, c, b, a): + super().__init__(c=c, b=b, a=a) + + +def test_record(): + # {{{ New, dict-based Record + + r1 = SimpleRecord(a=1, b=2) + assert r1.a == 1 + assert r1.b == 2 + + r2 = r1.copy() + assert r2.a == 1 + assert r1 == r2 + + r3 = r1.copy(b=3) + assert r3.b == 3 + assert r1 != r3 + + assert str(r1) == str(r2) == "SimpleRecord(a=1, b=2)" + assert str(r3) == "SimpleRecord(a=1, b=3)" + + # Unregistered fields are (silently) ignored for printing + r1.f = 6 + assert str(r1) == "SimpleRecord(a=1, b=2)" + + # Registered fields are printed + r1.register_fields({"d", "e"}) + assert str(r1) == "SimpleRecord(a=1, b=2)" + + r1.d = 4 + r1.e = 5 + assert str(r1) == "SimpleRecord(a=1, b=2, d=4, e=5)" + + with pytest.raises(AttributeError): + r1.ff + + # Test pickling + + import pickle + r1_pickled = pickle.loads(pickle.dumps(r1)) + assert r1 == r1_pickled + + class SimpleRecord2(Record): + pass + + r_new = SimpleRecord2(b=2, a=1) + assert r_new.a == 1 + assert r_new.b == 2 + + assert str(r_new) == "SimpleRecord2(b=2, a=1)" + + assert r_new != r1 + + # }}} + + # {{{ Legacy set-based record (used in Loopy) + + r = SetBasedRecord(3, 2, 1) + + # Fields are converted to a dict during __init__ + assert isinstance(r.fields, dict) + assert r.a == 1 + assert r.b == 2 + assert r.c == 3 + + # Fields are sorted alphabetically in set-based records + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + # Unregistered fields are (silently) ignored for printing + r.f = 6 + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + # Registered fields are printed + r.register_fields({"d", "e"}) + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + r.d = 4 + r.e = 5 + assert str(r) == "SetBasedRecord(a=1, b=2, c=3, d=4, e=5)" + + with pytest.raises(AttributeError): + r.ff + + # Test pickling + r_pickled = pickle.loads(pickle.dumps(r)) + assert r == r_pickled + + # }}} + + # {{{ __slots__, __dict__, __weakref__ handling + + class RecordWithEmptySlots(Record): + __slots__ = [] + + assert hasattr(RecordWithEmptySlots(), "__slots__") + assert not hasattr(RecordWithEmptySlots(), "__dict__") + assert not hasattr(RecordWithEmptySlots(), "__weakref__") + + class RecordWithUnsetSlots(Record): + pass + + assert hasattr(RecordWithUnsetSlots(), "__slots__") + assert hasattr(RecordWithUnsetSlots(), "__dict__") + assert hasattr(RecordWithUnsetSlots(), "__weakref__") + + from pytools import ImmutableRecord + + class ImmutableRecordWithEmptySlots(ImmutableRecord): + __slots__ = [] + + assert hasattr(ImmutableRecordWithEmptySlots(), "__slots__") + assert hasattr(ImmutableRecordWithEmptySlots(), "__dict__") + assert hasattr(ImmutableRecordWithEmptySlots(), "__weakref__") + + class ImmutableRecordWithUnsetSlots(ImmutableRecord): + pass + + assert hasattr(ImmutableRecordWithUnsetSlots(), "__slots__") + assert hasattr(ImmutableRecordWithUnsetSlots(), "__dict__") + assert hasattr(ImmutableRecordWithUnsetSlots(), "__weakref__") + + # }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])