From f2b2a0478eba0143325867b2a8f87fcc7c6cad90 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 22 Apr 2021 16:52:35 -0500
Subject: [PATCH] Un-methodize check_tag_uniqueness and normalize_tags

---
 pytools/tag.py | 83 +++++++++++++++++++++++++++++---------------------
 1 file changed, 49 insertions(+), 34 deletions(-)

diff --git a/pytools/tag.py b/pytools/tag.py
index ca541ad..8bc8559 100644
--- a/pytools/tag.py
+++ b/pytools/tag.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Tuple, Any, FrozenSet, Union, Iterable, TypeVar
+from typing import Tuple, Set, Any, FrozenSet, Union, Iterable, TypeVar
 from pytools import memoize
 
 __copyright__ = """
@@ -37,6 +37,11 @@ __doc__ = """
 Tag Interface
 ---------------
 
+.. comment::
+
+    ``normalize_tags`` undocumented for now. (Not ready to commit.)
+
+.. autofunction:: check_tag_uniqueness
 .. autoclass:: Taggable
 .. autoclass:: Tag
 .. autoclass:: UniqueTag
@@ -157,7 +162,7 @@ TagOrIterableType = Union[Iterable[Tag], Tag, None]
 T_co = TypeVar("T_co", bound="Taggable")
 
 
-# {{{ taggable
+# {{{ UniqueTag rules checking
 
 @memoize
 def _immediate_unique_tag_descendants(cls):
@@ -179,6 +184,45 @@ class NonUniqueTagError(ValueError):
     pass
 
 
+def check_tag_uniqueness(tags: TagsType):
+    """Ensure that *tags* obeys the rules set forth in :class:`UniqueTag`.
+    If not, raise :exc:`NonUniqueTagError`.
+
+    :returns: *tags*
+    """
+    unique_tag_descendants: Set[Tag] = set()
+    for tag in tags:
+        tag_unique_tag_descendants = _immediate_unique_tag_descendants(
+                type(tag))
+        intersection = unique_tag_descendants & tag_unique_tag_descendants
+        if intersection:
+            raise NonUniqueTagError("Multiple tags are direct subclasses of "
+                    "the following UniqueTag(s): "
+                    f"{', '.join(d.__name__ for d in intersection)}")
+        else:
+            unique_tag_descendants.update(tag_unique_tag_descendants)
+
+    return tags
+
+# }}}
+
+
+def normalize_tags(tags: TagOrIterableType) -> TagsType:
+    if isinstance(tags, Tag):
+        tags = frozenset([tags])
+    elif tags is None:
+        tags = frozenset()
+    else:
+        tags = frozenset(tags)
+
+    for tag in tags:
+        if not isinstance(tag, Tag):
+            raise TypeError(f"'{tag}' is not an instance of pytools.tag.Tag")
+    return tags
+
+
+# {{{ taggable
+
 class Taggable:
     """
     Parent class for objects with a `tags` attribute.
@@ -195,34 +239,7 @@ class Taggable:
     """
 
     def __init__(self, tags: TagsType = frozenset()):
-        # For performance we assert rather than
-        # normalize the input.
-        assert isinstance(tags, FrozenSet)
-        assert all(isinstance(tag, Tag) for tag in tags)
         self.tags = tags
-        self._check_uniqueness()
-
-    def _normalize_tags(self, tags: TagOrIterableType) -> TagsType:
-        if isinstance(tags, Tag):
-            t = frozenset([tags])
-        elif tags is None:
-            t = frozenset()
-        else:
-            t = frozenset(tags)
-        return t
-
-    def _check_uniqueness(self):
-        unique_tag_descendants = set()
-        for tag in self.tags:
-            tag_unique_tag_descendants = _immediate_unique_tag_descendants(
-                    type(tag))
-            intersection = unique_tag_descendants & tag_unique_tag_descendants
-            if intersection:
-                raise NonUniqueTagError("Multiple tags are direct subclasses of "
-                        "the following UniqueTag(s): "
-                        f"{', '.join(d.__name__ for d in intersection)}")
-            else:
-                unique_tag_descendants.update(tag_unique_tag_descendants)
 
     def copy(self: T_co, **kwargs: Any) -> T_co:
         """
@@ -241,10 +258,8 @@ class Taggable:
         :arg tags: An instance of :class:`Tag` or
         an iterable with instances therein.
         """
-        new_tags = self._normalize_tags(tags)
-        union_tags = self.tags | new_tags
-        cpy = self.copy(tags=union_tags)
-        return cpy
+        return self.copy(
+                tags=check_tag_uniqueness(normalize_tags(tags) | self.tags))
 
     def without_tags(self: T_co,
             tags: TagOrIterableType, verify_existence: bool = True) -> T_co:
@@ -259,7 +274,7 @@ class Taggable:
         for removal are present in the original set of tags. Default `True`
         """
 
-        to_remove = self._normalize_tags(tags)
+        to_remove = normalize_tags(tags)
         new_tags = self.tags - to_remove
         if verify_existence and len(new_tags) > len(self.tags) - len(to_remove):
             raise ValueError("A tag specified for removal was not present.")
-- 
GitLab