diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 4836d7f3775de4b321803130c3868f5bcd738bd5..4daeeb9d9f75633c6188b42fa743277b889cbadd 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -26,6 +26,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from enum import Enum import logging import hashlib import collections.abc as abc @@ -233,7 +234,8 @@ class KeyBuilder: digest = inner_key_hash.digest() if digest is None: - tname = type(key).__name__ + tp = type(key) + tname = tp.__name__ method = None try: method = getattr(self, "update_for_"+tname) @@ -245,6 +247,9 @@ class KeyBuilder: if isinstance(key, np.dtype): method = self.update_for_specific_dtype + elif issubclass(tp, Enum): + method = self.update_for_enum + if method is not None: inner_key_hash = self.new_hash() method(inner_key_hash, key) @@ -287,6 +292,10 @@ class KeyBuilder: except OverflowError: sz *= 2 + @classmethod + def update_for_enum(cls, key_hash, key): + cls.update_for_str(key_hash, str(key)) + @staticmethod def update_for_bool(key_hash, key): key_hash.update(str(key).encode("utf8"))