diff --git a/pytools/graph.py b/pytools/graph.py index c4c17956f0971d87f441f275c5874a7cea4b263f..f500805997a8f5593b96f8a699a8fa2cdf3610a1 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -64,11 +64,13 @@ Type Variables Used is included as a key in the graph. """ +from dataclasses import dataclass from typing import ( Any, Callable, Collection, Dict, + Generic, Hashable, Iterator, List, @@ -254,7 +256,8 @@ class CycleError(Exception): self.node = node -class HeapEntry: +@dataclass(frozen=True) +class HeapEntry(Generic[NodeT]): """ Helper class to compare associated keys while comparing the elements in heap operations. @@ -262,9 +265,8 @@ class HeapEntry: Only needs to define :func:`pytools.graph.__lt__` according to <https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>. """ - def __init__(self, node: NodeT, key: Any) -> None: - self.node = node - self.key = key + node: NodeT + key: Any def __lt__(self, other: HeapEntry) -> bool: return self.key < other.key