diff --git a/pytools/tag.py b/pytools/tag.py index ce05f48137b1a118ac3b0ba8dcf2104886eb08cf..dcc170c3b6b7c0038ea7102d51839fc018db9984 100644 --- a/pytools/tag.py +++ b/pytools/tag.py @@ -7,6 +7,7 @@ Tag Interface .. autoclass:: Taggable .. autoclass:: Tag .. autoclass:: UniqueTag +.. autoclass:: IgnoredForEqualityTag Supporting Functionality ------------------------ @@ -243,6 +244,7 @@ class Taggable: .. automethod:: tagged .. automethod:: without_tags .. automethod:: tags_of_type + .. automethod:: tags_not_of_type .. versionadded:: 2021.1 """ @@ -321,6 +323,44 @@ class Taggable: for tag in self.tags if isinstance(tag, tag_t)}) + @memoize_method + def tags_not_of_type(self, tag_t: Type[TagT]) -> FrozenSet[Tag]: + """ + Returns *self*'s tags that are not of type *tag_t*. + """ + return frozenset({tag + for tag in self.tags + if not isinstance(tag, tag_t)}) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Taggable): + return (self.tags_not_of_type(IgnoredForEqualityTag) + == other.tags_not_of_type(IgnoredForEqualityTag)) + else: + return super().__eq__(other) + + def __hash__(self) -> int: + return hash(self.tags_not_of_type(IgnoredForEqualityTag)) + + +# }}} + + +# {{{ IgnoredForEqualityTag + +class IgnoredForEqualityTag(Tag): + """ + A superclass for tags that are ignored when testing equality of instances of + :class:`Taggable`. + + When testing equality of two instances of :class:`Taggable`, the equality + of the ``tags`` of both instances is tested after removing all + instances of :class:`IgnoredForEqualityTag`. Instances of + :class:`IgnoredForEqualityTag` are removed for hashing instances of + :class:`Taggable`. + """ + pass + # }}} diff --git a/test/test_pytools.py b/test/test_pytools.py index d53394269c12085a19a27e51c61c1d635efddcb1..a302d1babcdde27c5703c548d4e9d5e7018b50cb 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -26,6 +26,7 @@ import pytest import logging logger = logging.getLogger(__name__) +from typing import FrozenSet @pytest.mark.skipif("sys.version_info < (2, 5)") @@ -558,7 +559,7 @@ def test_tag(): blue_ribbon, red_ribbon)) # Test that tagged() fails if a UniqueTag of the same subclass - # is alredy present + # is already present with pytest.raises(NonUniqueTagError): t1.tagged(best_in_show_ribbon) @@ -664,6 +665,40 @@ def test_unique_name_gen_conflicting_ok(): ung.add_names({"a", "b", "c"}, conflicting_ok=True) +def test_ignoredforequalitytag(): + from pytools.tag import IgnoredForEqualityTag, Tag, Taggable + + # Need a subclass that defines _with_new_tags in order to test. + class TaggableWithNewTags(Taggable): + + def _with_new_tags(self, tags: FrozenSet[Tag]): + return TaggableWithNewTags(tags) + + class Eq1(IgnoredForEqualityTag): + pass + + class Eq2(IgnoredForEqualityTag): + pass + + class Eq3(Tag): + pass + + eq1 = TaggableWithNewTags(frozenset([Eq1()])) + eq2 = TaggableWithNewTags(frozenset([Eq2()])) + eq12 = TaggableWithNewTags(frozenset([Eq1(), Eq2()])) + eq3 = TaggableWithNewTags(frozenset([Eq1(), Eq3()])) + + assert eq1 == eq2 == eq12 + assert eq1 != eq3 + + assert eq1.without_tags(Eq1()) + with pytest.raises(ValueError): + eq3.without_tags(Eq2()) + + assert hash(eq1) == hash(eq2) == hash(eq12) + assert hash(eq1) != hash(eq3) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])