diff --git a/pytools/tag.py b/pytools/tag.py index e9a4cc69f766d0379bdc01334f2f5d40c30c7cc0..b943d954fc48fcb03b0d78bd7d291d566c8d78ef 100644 --- a/pytools/tag.py +++ b/pytools/tag.py @@ -1,10 +1,12 @@ from dataclasses import dataclass -from typing import Tuple, Any, FrozenSet +from typing import Tuple, Any, FrozenSet, Union, Iterable, TypeVar +from pytools import memoize __copyright__ = """ -Copyright (C) 2020 Andreas Kloeckner +Copyright (C) 2020 Andreas Klöckner Copyright (C) 2020 Matt Wala Copyright (C) 2020 Xiaoyu Wei +Copyright (C) 2020 Nicholas Christensen """ __license__ = """ @@ -27,6 +29,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + # {{{ docs __doc__ = """ @@ -34,6 +37,7 @@ __doc__ = """ Tag Interface --------------- +.. autoclass:: Taggable .. autoclass:: Tag .. autoclass:: UniqueTag @@ -41,14 +45,14 @@ Supporting Functionality ------------------------ .. autoclass:: DottedName +.. autoclass:: NonUniqueTagError """ -# }}} - +# }}} -# {{{ dotted name +# {{{ dotted name class DottedName: """ @@ -89,6 +93,7 @@ class DottedName: # }}} + # {{{ tag tag_dataclass = dataclass(init=True, eq=True, frozen=True, repr=True) @@ -117,16 +122,138 @@ class Tag: def tag_name(self) -> DottedName: return DottedName.from_class(type(self)) +# }}} + + +# {{{ unique tag class UniqueTag(Tag): """ - Only one instance of this type of tag may be assigned - to a single tagged object. + A superclass for tags that are unique on each :class:`Taggable`. + + Each instance of :class:`Taggable` may have no more than one + instance of each subclass of :class:`UniqueTag` in its + set of `tags`. Multiple `UniqueTag` instances of + different (immediate) subclasses are allowed. """ pass +# }}} + TagsType = FrozenSet[Tag] +TagOrIterableType = Union[Iterable[Tag], Tag, None] +T_co = TypeVar("T_co", bound="Taggable") + + +# {{{ taggable + +@memoize +def _immediate_unique_tag_descendants(cls): + if UniqueTag in cls.__bases__: + return frozenset([cls]) + else: + result = frozenset() + for base in cls.__bases__: + result = result | _immediate_unique_tag_descendants(base) + return result + + +class NonUniqueTagError(ValueError): + """ + Raised when a :class:`Taggable` object is instantiated with more + than one :class:`UniqueTag` instances of the same subclass in + its set of tags. + """ + pass + + +class Taggable: + """ + Parent class for objects with a `tags` attribute. + + .. attribute:: tags + + A :class:`frozenset` of :class:`Tag` instances + + .. method:: copy + .. method:: tagged + .. method:: without_tags + + .. versionadded:: 2021.1 + """ + + def __init__(self, tags: TagsType = frozenset()): + # For performance we assert rather than + # normalize the input. + assert isinstance(tags, FrozenSet) + assert all(isinstance(tag, Tag) for tag in tags) + self.tags = tags + self._check_uniqueness() + + def _normalize_tags(self, tags: TagOrIterableType) -> TagsType: + if isinstance(tags, Tag): + t = frozenset([tags]) + elif tags is None: + t = frozenset() + else: + t = frozenset(tags) + return t + + def _check_uniqueness(self): + unique_tag_descendants = set() + for tag in self.tags: + tag_unique_tag_descendants = _immediate_unique_tag_descendants( + type(tag)) + intersection = unique_tag_descendants & tag_unique_tag_descendants + if intersection: + raise NonUniqueTagError("Multiple tags are direct subclasses of " + "the following UniqueTag(s): " + f"{', '.join(d.__name__ for d in intersection)}") + else: + unique_tag_descendants.update(tag_unique_tag_descendants) + + def copy(self: T_co, **kwargs: Any) -> T_co: + """ + Returns of copy of *self* with the specified tags. This method + should be overridden by subclasses. + """ + raise NotImplementedError("The copy function is not implemented.") + + def tagged(self: T_co, tags: TagOrIterableType) -> T_co: + """ + Return a copy of *self* with the specified + tag or tags unioned. If *tags* is a :class:`pytools.tag.UniqueTag` + and other tags of this type are already present, an error is raised + Assumes `self.copy(tags=)` is implemented. + + :arg tags: An instance of :class:`Tag` or + an iterable with instances therein. + """ + new_tags = self._normalize_tags(tags) + union_tags = self.tags | new_tags + cpy = self.copy(tags=union_tags) + return cpy + + def without_tags(self: T_co, + tags: TagOrIterableType, verify_existence: bool = True) -> T_co: + """ + Return a copy of *self* without the specified tags. + `self.copy(tags=)` is implemented. + + :arg tags: An instance of :class:`Tag` or an iterable with instances + therein. + :arg verify_existence: If set + to `True`, this method raises an exception if not all tags specified + for removal are present in the original set of tags. Default `True` + """ + + to_remove = self._normalize_tags(tags) + new_tags = self.tags - to_remove + if verify_existence and len(new_tags) > len(self.tags) - len(to_remove): + raise ValueError("A tag specified for removal was not present.") + + return self.copy(tags=new_tags) # }}} diff --git a/pytools/version.py b/pytools/version.py index 99b6b21563e5660d93da0deb423168b95d50af55..e4c47bc9ed6a54934f00ad2e03256e2fa19c60f7 100644 --- a/pytools/version.py +++ b/pytools/version.py @@ -1,3 +1,3 @@ -VERSION = (2020, 4, 4) +VERSION = (2021, 1) VERSION_STATUS = "" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS diff --git a/test/test_pytools.py b/test/test_pytools.py index 0e8f99205afdeb98a76d53a2ddff52aaae3e4413..87b5d3a76d2d76e4a61e8343252e9b81cd7c767c 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -283,6 +283,93 @@ def test_make_obj_array_iteration(): # }}} +def test_tag(): + from pytools.tag import Taggable, Tag, UniqueTag, NonUniqueTagError + + # Need a subclass that defines the copy function in order to test. + class TaggableWithCopy(Taggable): + + def copy(self, **kwargs): + return TaggableWithCopy(kwargs["tags"]) + + class FairRibbon(Tag): + pass + + class BlueRibbon(FairRibbon): + pass + + class RedRibbon(FairRibbon): + pass + + class ShowRibbon(FairRibbon, UniqueTag): + pass + + class BestInShowRibbon(ShowRibbon): + pass + + class ReserveBestInShowRibbon(ShowRibbon): + pass + + class BestInClassRibbon(FairRibbon, UniqueTag): + pass + + best_in_show_ribbon = BestInShowRibbon() + reserve_best_in_show_ribbon = ReserveBestInShowRibbon() + blue_ribbon = BlueRibbon() + red_ribbon = RedRibbon() + best_in_class_ribbon = BestInClassRibbon() + + # Test that instantiation fails if tags is not a FrozenSet of Tags + with pytest.raises(AssertionError): + TaggableWithCopy(tags=[best_in_show_ribbon, reserve_best_in_show_ribbon, + blue_ribbon, red_ribbon]) + + # Test that instantiation fails if tags is not a FrozenSet of Tags + with pytest.raises(AssertionError): + TaggableWithCopy(tags=frozenset((1, reserve_best_in_show_ribbon, blue_ribbon, + red_ribbon))) + + # Test that instantiation fails if there are multiple instances + # of the same UniqueTag subclass + with pytest.raises(NonUniqueTagError): + TaggableWithCopy(tags=frozenset((best_in_show_ribbon, + reserve_best_in_show_ribbon, blue_ribbon, red_ribbon))) + + # Test that instantiation succeeds if there are multiple instances + # Tag subclasses. + t1 = TaggableWithCopy(frozenset([reserve_best_in_show_ribbon, blue_ribbon, + red_ribbon])) + assert t1.tags == frozenset((reserve_best_in_show_ribbon, red_ribbon, + blue_ribbon)) + + # Test that instantiation succeeds if there are multiple instances + # of UniqueTag of different subclasses. + t1 = TaggableWithCopy(frozenset([reserve_best_in_show_ribbon, + best_in_class_ribbon, blue_ribbon, + blue_ribbon])) + assert t1.tags == frozenset((reserve_best_in_show_ribbon, best_in_class_ribbon, + blue_ribbon)) + + # Test tagged() function + t2 = t1.tagged(red_ribbon) + print(t2.tags) + assert t2.tags == frozenset((reserve_best_in_show_ribbon, best_in_class_ribbon, + blue_ribbon, red_ribbon)) + + # Test that tagged() fails if a UniqueTag of the same subclass + # is alredy present + with pytest.raises(NonUniqueTagError): + t1.tagged(best_in_show_ribbon) + + # Test without_tags() function + t4 = t2.without_tags(red_ribbon) + assert t4.tags == t1.tags + + # Test that without_tags() fails if the tag is not present. + with pytest.raises(ValueError): + t4.without_tags(red_ribbon) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])