diff --git a/pytools/__init__.py b/pytools/__init__.py index 57e33787e7660dd15729fd633d1d892e21946825..c096e99ff65b8dd0c852275ef4b82a2fea2d5ad9 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -315,12 +315,18 @@ class Record(RecordWithoutPickling): class ImmutableRecordWithoutPickling(RecordWithoutPickling): "Hashable record. Does not explicitly enforce immutability." + def __init__(self, *args, **kwargs): + RecordWithoutPickling.__init__(self, *args, **kwargs) + self._cached_hash = None def __hash__(self): - return hash( + if self._cached_hash is None: + self._cached_hash = hash( (type(self),) + tuple(getattr(self, field) for field in self.__class__.fields)) + return self._cached_hash + class ImmutableRecord(ImmutableRecordWithoutPickling, Record): pass diff --git a/pytools/lex.py b/pytools/lex.py index 2ad9cee47e6800c7477692cae9408300e7768749..ff9d8a86e645e82f23636abd3eaf79d07c16746e 100644 --- a/pytools/lex.py +++ b/pytools/lex.py @@ -148,8 +148,8 @@ class LexIterator(object): def raise_parse_error(self, msg): if self.is_at_end(): raise ParseError(msg, self.raw_string, None) - else: - raise ParseError(msg, self.raw_string, self.lexed[self.index]) + + raise ParseError(msg, self.raw_string, self.lexed[self.index]) def expected(self, what_expected): if self.is_at_end(): diff --git a/pytools/log.py b/pytools/log.py index 8d579bde1876264188f9623fe480c09e7360941c..104cc82dd2168f6cdc9faf71bd81cde134beafa1 100644 --- a/pytools/log.py +++ b/pytools/log.py @@ -404,7 +404,8 @@ class LogManager(object): # we've opened an existing database if mode == "w": raise RuntimeError("Log database '%s' already exists" % filename) - elif mode == "wu": + + if mode == "wu": # try again with a new suffix continue @@ -446,9 +447,9 @@ class LogManager(object): if self.old_showwarning is None: raise RuntimeError( "Warnings capture was disabled, but never enabled") - else: - warnings.showwarning = self.old_showwarning - self.old_showwarning = None + + warnings.showwarning = self.old_showwarning + self.old_showwarning = None def _load(self): if self.mpi_comm and self.mpi_comm.rank != self.head_rank: @@ -859,8 +860,8 @@ class LogManager(object): if self.is_parallel: raise ValueError( "must specify explicit aggregator for '%s'" % name) - else: - agg_func = lambda lst: lst[0] + + agg_func = lambda lst: lst[0] elif isinstance(dep, Lookup): assert isinstance(dep.aggregate, Variable) name = dep.aggregate.name