diff --git a/pytools/tag.py b/pytools/tag.py index 0eff8f4fe56e04c8ceb0f8887406ecc9c6c43497..b463ff443d75a99f93ec302b373222c750a051ac 100644 --- a/pytools/tag.py +++ b/pytools/tag.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import Tuple, Any, FrozenSet, Union, Iterable, TypeVar +from pytools import memoize __copyright__ = """ Copyright (C) 2020 Andreas Klöckner @@ -28,19 +29,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from pytools import memoize - - -@memoize -def _immediate_unique_tag_descendants(cls): - if UniqueTag in cls.__bases__: - return frozenset([cls]) - else: - result = frozenset() - for base in cls.__bases__: - result = result | _immediate_unique_tag_descendants(base) - return result - # {{{ docs @@ -155,6 +143,17 @@ T_co = TypeVar("T_co", bound="Taggable") # {{{ taggable +@memoize +def _immediate_unique_tag_descendants(cls): + if UniqueTag in cls.__bases__: + return frozenset([cls]) + else: + result = frozenset() + for base in cls.__bases__: + result = result | _immediate_unique_tag_descendants(base) + return result + + class Taggable: """ Parent class for objects with a `tags` attribute. @@ -188,25 +187,6 @@ class Taggable: return t def _check_uniqueness(self): - class_set = set() - - def _check_recursively(in_class, class_set): - - # Termination condition - if issubclass(in_class, UniqueTag) and in_class is not UniqueTag: - if in_class in class_set: - error_string = ("Two or more Tags are instances of {}." - " A Taggable object can only instantiate with one" - " instance of each UniqueTag sub-class.").format( - in_class.__name__) - raise ValueError(error_string) - else: - class_set.add(in_class) - - # Recurse to all superclasses - for c in in_class.__bases__: - _check_recursively(c, class_set) - unique_tag_descendants = set() for tag in self.tags: tag_unique_tag_descendants = _immediate_unique_tag_descendants( @@ -216,6 +196,8 @@ class Taggable: raise ValueError("Multiple tags are direct subclasses of " "the following UniqueTag(s): " f"{', '.join(d.__name__ for d in intersection)}") + else: + unique_tag_descendants.update(tag_unique_tag_descendants) def copy(self: T_co, **kwargs: Any) -> T_co: """