diff --git a/pytools/tag.py b/pytools/tag.py index cc6922ccead2ed61a1e81633ea4fadc0c713bda1..0eff8f4fe56e04c8ceb0f8887406ecc9c6c43497 100644 --- a/pytools/tag.py +++ b/pytools/tag.py @@ -28,6 +28,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from pytools import memoize + + +@memoize +def _immediate_unique_tag_descendants(cls): + if UniqueTag in cls.__bases__: + return frozenset([cls]) + else: + result = frozenset() + for base in cls.__bases__: + result = result | _immediate_unique_tag_descendants(base) + return result + + # {{{ docs __doc__ = """ @@ -90,6 +104,7 @@ class DottedName: # }}} + # {{{ tag tag_dataclass = dataclass(init=True, eq=True, frozen=True, repr=True) @@ -118,6 +133,10 @@ class Tag: def tag_name(self) -> DottedName: return DottedName.from_class(type(self)) +# }}} + + +# {{{ unique tag class UniqueTag(Tag): """ @@ -126,12 +145,16 @@ class UniqueTag(Tag): """ pass +# }}} + TagsType = FrozenSet[Tag] TagOrIterableType = Union[Iterable[Tag], Tag, None] T_co = TypeVar("T_co", bound="Taggable") +# {{{ taggable + class Taggable: """ Parent class for objects with a `tags` attribute. @@ -184,8 +207,15 @@ class Taggable: for c in in_class.__bases__: _check_recursively(c, class_set) + unique_tag_descendants = set() for tag in self.tags: - _check_recursively(type(tag), class_set) + tag_unique_tag_descendants = _immediate_unique_tag_descendants( + type(tag)) + intersection = unique_tag_descendants & tag_unique_tag_descendants + if intersection: + raise ValueError("Multiple tags are direct subclasses of " + "the following UniqueTag(s): " + f"{', '.join(d.__name__ for d in intersection)}") def copy(self: T_co, **kwargs: Any) -> T_co: """