From d88292e0029cc33487bbabd49d9b5de78456c6b7 Mon Sep 17 00:00:00 2001
From: Nicholas Christensen <njchris2@illinois.edu>
Date: Thu, 24 Dec 2020 21:34:52 -0600
Subject: [PATCH] update tag set

---
 pytools/tag.py | 46 ++++++++++++++--------------------------------
 1 file changed, 14 insertions(+), 32 deletions(-)

diff --git a/pytools/tag.py b/pytools/tag.py
index 0eff8f4..b463ff4 100644
--- a/pytools/tag.py
+++ b/pytools/tag.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from typing import Tuple, Any, FrozenSet, Union, Iterable, TypeVar
+from pytools import memoize
 
 __copyright__ = """
 Copyright (C) 2020 Andreas Klöckner
@@ -28,19 +29,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from pytools import memoize
-
-
-@memoize
-def _immediate_unique_tag_descendants(cls):
-    if UniqueTag in cls.__bases__:
-        return frozenset([cls])
-    else:
-        result = frozenset()
-        for base in cls.__bases__:
-            result = result | _immediate_unique_tag_descendants(base)
-        return result
-
 
 # {{{ docs
 
@@ -155,6 +143,17 @@ T_co = TypeVar("T_co", bound="Taggable")
 
 # {{{ taggable
 
+@memoize
+def _immediate_unique_tag_descendants(cls):
+    if UniqueTag in cls.__bases__:
+        return frozenset([cls])
+    else:
+        result = frozenset()
+        for base in cls.__bases__:
+            result = result | _immediate_unique_tag_descendants(base)
+        return result
+
+
 class Taggable:
     """
     Parent class for objects with a `tags` attribute.
@@ -188,25 +187,6 @@ class Taggable:
         return t
 
     def _check_uniqueness(self):
-        class_set = set()
-
-        def _check_recursively(in_class, class_set):
-
-            # Termination condition
-            if issubclass(in_class, UniqueTag) and in_class is not UniqueTag:
-                if in_class in class_set:
-                    error_string = ("Two or more Tags are instances of {}."
-                    " A Taggable object can only instantiate with one"
-                    " instance of each UniqueTag sub-class.").format(
-                        in_class.__name__)
-                    raise ValueError(error_string)
-                else:
-                    class_set.add(in_class)
-
-                # Recurse to all superclasses
-                for c in in_class.__bases__:
-                    _check_recursively(c, class_set)
-
         unique_tag_descendants = set()
         for tag in self.tags:
             tag_unique_tag_descendants = _immediate_unique_tag_descendants(
@@ -216,6 +196,8 @@ class Taggable:
                 raise ValueError("Multiple tags are direct subclasses of "
                         "the following UniqueTag(s): "
                         f"{', '.join(d.__name__ for d in intersection)}")
+            else:
+                unique_tag_descendants.update(tag_unique_tag_descendants)
 
     def copy(self: T_co, **kwargs: Any) -> T_co:
         """
-- 
GitLab