diff --git a/test/test_pytools.py b/test/test_pytools.py index 60a7f712b37b1f745511802d7a1d16df1d53bd0d..6c24ecc501c80dde26566a48ad86bb8924a25dc5 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -443,7 +443,9 @@ def test_obj_array_vectorize(c=1): def test_tag(): - from pytools.tag import Taggable, Tag, UniqueTag, NonUniqueTagError + from pytools.tag import ( + Taggable, Tag, UniqueTag, NonUniqueTagError, check_tag_uniqueness + ) # Need a subclass that defines the copy function in order to test. class TaggableWithCopy(Taggable): @@ -478,21 +480,19 @@ def test_tag(): red_ribbon = RedRibbon() best_in_class_ribbon = BestInClassRibbon() - # Test that instantiation fails if tags is not a FrozenSet of Tags - with pytest.raises(AssertionError): - TaggableWithCopy(tags=[best_in_show_ribbon, reserve_best_in_show_ribbon, - blue_ribbon, red_ribbon]) - - # Test that instantiation fails if tags is not a FrozenSet of Tags - with pytest.raises(AssertionError): - TaggableWithCopy(tags=frozenset((1, reserve_best_in_show_ribbon, blue_ribbon, - red_ribbon))) - - # Test that instantiation fails if there are multiple instances + # Test that input processing fails if there are multiple instances # of the same UniqueTag subclass with pytest.raises(NonUniqueTagError): - TaggableWithCopy(tags=frozenset((best_in_show_ribbon, - reserve_best_in_show_ribbon, blue_ribbon, red_ribbon))) + check_tag_uniqueness(frozenset(( + best_in_show_ribbon, + reserve_best_in_show_ribbon, blue_ribbon, red_ribbon))) + + # Test that input processing fails if any of the tags are not + # a subclass of Tag + with pytest.raises(TypeError): + check_tag_uniqueness(frozenset(( + "I am not a tag", best_in_show_ribbon, + blue_ribbon, red_ribbon))) # Test that instantiation succeeds if there are multiple instances # Tag subclasses. @@ -520,6 +520,10 @@ def test_tag(): with pytest.raises(NonUniqueTagError): t1.tagged(best_in_show_ribbon) + # Test that tagged() fails if tags are not a FrozenSet of Tags + with pytest.raises(TypeError): + t1.tagged(tags=frozenset((1,))) + # Test without_tags() function t4 = t2.without_tags(red_ribbon) assert t4.tags == t1.tags