diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 0b59c975f1266f6ba8633dc8c19cc4d990c36e48..f00f48e0adcd7975863473e8db749eff0ee855d5 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -134,26 +134,40 @@ class Mapper: implementations. """ - try: - method = getattr(self, expr.mapper_method) - except AttributeError: - if isinstance(expr, primitives.Expression): - for cls in type(expr).__mro__[1:]: - method_name = getattr(cls, "mapper_method", None) - if method_name: - method = getattr(self, method_name, None) - if method: - break - else: - return self.handle_unsupported_expression( - expr, *args, **kwargs) + method_name = getattr(expr, "mapper_method", None) + if method_name is not None: + method = getattr(self, method_name, None) + if method is not None: + result = method(expr, *args, **kwargs) + return result + + if isinstance(expr, primitives.Expression): + for cls in type(expr).__mro__[1:]: + method_name = getattr(cls, "mapper_method", None) + if method_name: + method = getattr(self, method_name, None) + if method: + return method(expr, *args, **kwargs) else: - return self.map_foreign(expr, *args, **kwargs) - - return method(expr, *args, **kwargs) + return self.handle_unsupported_expression(expr, *args, **kwargs) + else: + return self.map_foreign(expr, *args, **kwargs) rec = __call__ + def rec_fallback(self, expr, *args, **kwargs): + if isinstance(expr, primitives.Expression): + for cls in type(expr).__mro__[1:]: + method_name = getattr(cls, "mapper_method", None) + if method_name: + method = getattr(self, method_name, None) + if method: + return method(expr, *args, **kwargs) + else: + return self.handle_unsupported_expression(expr, *args, **kwargs) + else: + return self.map_foreign(expr, *args, **kwargs) + def map_algebraic_leaf(self, expr, *args, **kwargs): raise NotImplementedError @@ -210,6 +224,9 @@ class Mapper: self.__class__, repr(expr))) +_NOT_IN_CACHE = object() + + class CachedMapper(Mapper): """ A mapper that memoizes the mapped result for the expressions traversed. @@ -237,14 +254,24 @@ class CachedMapper(Mapper): return (type(expr), expr, args, tuple(sorted(kwargs.items()))) def __call__(self, expr, *args, **kwargs): - cache_key = self.get_cache_key(expr, *args, **kwargs) - try: - return self._cache[cache_key] - except KeyError: - result = super().rec(expr, *args, **kwargs) - self._cache[cache_key] = result + result = self._cache.get( + (cache_key := self.get_cache_key(expr, *args, **kwargs)), + _NOT_IN_CACHE) + if result is not _NOT_IN_CACHE: return result + method_name = getattr(expr, "mapper_method", None) + if method_name is not None: + method = getattr(self, method_name, None) + if method is not None: + result = method(expr, *args, **kwargs) + self._cache[cache_key] = result + return result + + result = self.rec_fallback(expr, *args, **kwargs) + self._cache[cache_key] = result + return result + rec = __call__ # }}} @@ -572,12 +599,12 @@ class IdentityMapper(Mapper): return [self.rec(child, *args, **kwargs) for child in expr] def map_tuple(self, expr, *args, **kwargs): - children = tuple([self.rec(child, *args, **kwargs) for child in expr]) + children = [self.rec(child, *args, **kwargs) for child in expr] if all(child is orig_child for child, orig_child in zip(children, expr)): return expr - return children + return tuple(children) def map_numpy_array(self, expr, *args, **kwargs): import numpy @@ -995,8 +1022,21 @@ class CachingMapperMixin: return self.result_cache[expr] except TypeError: # not hashable, oh well + method_name = getattr(expr, "mapper_method", None) + if method_name is not None: + method = getattr(self, method_name, None) + if method is not None: + return method(expr, ) return super().rec(expr) except KeyError: + method_name = getattr(expr, "mapper_method", None) + if method_name is not None: + method = getattr(self, method_name, None) + if method is not None: + result = method(expr, ) + self.result_cache[expr] = result + return result + result = super().rec(expr) self.result_cache[expr] = result return result