diff --git a/pytools/tag.py b/pytools/tag.py index ca541ad65cc8bdbbcae860cd5030d0a1c2ec8ac8..42d9e50f33f4e4a8868733c8ace15b63f0c67c5e 100644 --- a/pytools/tag.py +++ b/pytools/tag.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Tuple, Any, FrozenSet, Union, Iterable, TypeVar +from typing import Tuple, Set, Any, FrozenSet, Union, Iterable, TypeVar from pytools import memoize __copyright__ = """ @@ -36,7 +36,9 @@ __doc__ = """ Tag Interface --------------- +.. ``normalize_tags`` undocumented for now. (Not ready to commit.) +.. autofunction:: check_tag_uniqueness .. autoclass:: Taggable .. autoclass:: Tag .. autoclass:: UniqueTag @@ -47,6 +49,13 @@ Supporting Functionality .. autoclass:: DottedName .. autoclass:: NonUniqueTagError + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: T_co + + A covariant type variable used in, e.g. :class:`Taggable.copy`. """ # }}} @@ -157,7 +166,7 @@ TagOrIterableType = Union[Iterable[Tag], Tag, None] T_co = TypeVar("T_co", bound="Taggable") -# {{{ taggable +# {{{ UniqueTag rules checking @memoize def _immediate_unique_tag_descendants(cls): @@ -179,6 +188,44 @@ class NonUniqueTagError(ValueError): pass +def check_tag_uniqueness(tags: TagsType): + """Ensure that *tags* obeys the rules set forth in :class:`UniqueTag`. + If not, raise :exc:`NonUniqueTagError`. If any *tags* are not + subclasses of :class:`Tag`, a :exc:`TypeError` will be raised. + + :returns: *tags* + """ + unique_tag_descendants: Set[Tag] = set() + for tag in tags: + if not isinstance(tag, Tag): + raise TypeError(f"'{tag}' is not an instance of pytools.tag.Tag") + tag_unique_tag_descendants = _immediate_unique_tag_descendants( + type(tag)) + intersection = unique_tag_descendants & tag_unique_tag_descendants + if intersection: + raise NonUniqueTagError("Multiple tags are direct subclasses of " + "the following UniqueTag(s): " + f"{', '.join(d.__name__ for d in intersection)}") + else: + unique_tag_descendants.update(tag_unique_tag_descendants) + + return tags + +# }}} + + +def normalize_tags(tags: TagOrIterableType) -> TagsType: + if isinstance(tags, Tag): + tags = frozenset([tags]) + elif tags is None: + tags = frozenset() + else: + tags = frozenset(tags) + return tags + + +# {{{ taggable + class Taggable: """ Parent class for objects with a `tags` attribute. @@ -187,42 +234,30 @@ class Taggable: A :class:`frozenset` of :class:`Tag` instances - .. method:: copy - .. method:: tagged - .. method:: without_tags + .. automethod:: __init__ + + .. automethod:: copy + .. automethod:: tagged + .. automethod:: without_tags .. versionadded:: 2021.1 """ + # ReST references in docstrings must be fully qualified, as docstrings may + # be inherited and appear in different contexts. + def __init__(self, tags: TagsType = frozenset()): - # For performance we assert rather than - # normalize the input. - assert isinstance(tags, FrozenSet) - assert all(isinstance(tag, Tag) for tag in tags) + """ + Constructor for all objects that possess a `tags` attribute. + + :arg tags: a :class:`frozenset` of :class:`~pytools.tag.Tag` objects. + Tags can be modified via the :meth:`~pytools.tag.Taggable.tagged` and + :meth:`~pytools.tag.Taggable.without_tags` routines. Input checking + of *tags* should be performed before creating a + :class:`~pytools.tag.Taggable` instance, using + :func:`~pytools.tag.check_tag_uniqueness`. + """ self.tags = tags - self._check_uniqueness() - - def _normalize_tags(self, tags: TagOrIterableType) -> TagsType: - if isinstance(tags, Tag): - t = frozenset([tags]) - elif tags is None: - t = frozenset() - else: - t = frozenset(tags) - return t - - def _check_uniqueness(self): - unique_tag_descendants = set() - for tag in self.tags: - tag_unique_tag_descendants = _immediate_unique_tag_descendants( - type(tag)) - intersection = unique_tag_descendants & tag_unique_tag_descendants - if intersection: - raise NonUniqueTagError("Multiple tags are direct subclasses of " - "the following UniqueTag(s): " - f"{', '.join(d.__name__ for d in intersection)}") - else: - unique_tag_descendants.update(tag_unique_tag_descendants) def copy(self: T_co, **kwargs: Any) -> T_co: """ @@ -239,12 +274,10 @@ class Taggable: Assumes `self.copy(tags=)` is implemented. :arg tags: An instance of :class:`Tag` or - an iterable with instances therein. + an iterable with instances therein. """ - new_tags = self._normalize_tags(tags) - union_tags = self.tags | new_tags - cpy = self.copy(tags=union_tags) - return cpy + return self.copy( + tags=check_tag_uniqueness(normalize_tags(tags) | self.tags)) def without_tags(self: T_co, tags: TagOrIterableType, verify_existence: bool = True) -> T_co: @@ -253,18 +286,18 @@ class Taggable: `self.copy(tags=)` is implemented. :arg tags: An instance of :class:`Tag` or an iterable with instances - therein. - :arg verify_existence: If set - to `True`, this method raises an exception if not all tags specified - for removal are present in the original set of tags. Default `True` + therein. + :arg verify_existence: If set to `True`, this method raises + an exception if not all tags specified for removal are + present in the original set of tags. Default `True`. """ - to_remove = self._normalize_tags(tags) + to_remove = normalize_tags(tags) new_tags = self.tags - to_remove if verify_existence and len(new_tags) > len(self.tags) - len(to_remove): raise ValueError("A tag specified for removal was not present.") - return self.copy(tags=new_tags) + return self.copy(tags=check_tag_uniqueness(new_tags)) # }}} diff --git a/pytools/version.py b/pytools/version.py index 6022c9e3dc967dca5576bfb741d57c0df44a8e22..2d219656d699ae58fe1b9167cf43f9245d0109ed 100644 --- a/pytools/version.py +++ b/pytools/version.py @@ -1,3 +1,3 @@ -VERSION = (2021, 2, 3) +VERSION = (2021, 2, 4) VERSION_STATUS = "" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS diff --git a/test/test_pytools.py b/test/test_pytools.py index 60a7f712b37b1f745511802d7a1d16df1d53bd0d..6c24ecc501c80dde26566a48ad86bb8924a25dc5 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -443,7 +443,9 @@ def test_obj_array_vectorize(c=1): def test_tag(): - from pytools.tag import Taggable, Tag, UniqueTag, NonUniqueTagError + from pytools.tag import ( + Taggable, Tag, UniqueTag, NonUniqueTagError, check_tag_uniqueness + ) # Need a subclass that defines the copy function in order to test. class TaggableWithCopy(Taggable): @@ -478,21 +480,19 @@ def test_tag(): red_ribbon = RedRibbon() best_in_class_ribbon = BestInClassRibbon() - # Test that instantiation fails if tags is not a FrozenSet of Tags - with pytest.raises(AssertionError): - TaggableWithCopy(tags=[best_in_show_ribbon, reserve_best_in_show_ribbon, - blue_ribbon, red_ribbon]) - - # Test that instantiation fails if tags is not a FrozenSet of Tags - with pytest.raises(AssertionError): - TaggableWithCopy(tags=frozenset((1, reserve_best_in_show_ribbon, blue_ribbon, - red_ribbon))) - - # Test that instantiation fails if there are multiple instances + # Test that input processing fails if there are multiple instances # of the same UniqueTag subclass with pytest.raises(NonUniqueTagError): - TaggableWithCopy(tags=frozenset((best_in_show_ribbon, - reserve_best_in_show_ribbon, blue_ribbon, red_ribbon))) + check_tag_uniqueness(frozenset(( + best_in_show_ribbon, + reserve_best_in_show_ribbon, blue_ribbon, red_ribbon))) + + # Test that input processing fails if any of the tags are not + # a subclass of Tag + with pytest.raises(TypeError): + check_tag_uniqueness(frozenset(( + "I am not a tag", best_in_show_ribbon, + blue_ribbon, red_ribbon))) # Test that instantiation succeeds if there are multiple instances # Tag subclasses. @@ -520,6 +520,10 @@ def test_tag(): with pytest.raises(NonUniqueTagError): t1.tagged(best_in_show_ribbon) + # Test that tagged() fails if tags are not a FrozenSet of Tags + with pytest.raises(TypeError): + t1.tagged(tags=frozenset((1,))) + # Test without_tags() function t4 = t2.without_tags(red_ribbon) assert t4.tags == t1.tags