diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index c549938d13b62438593b153b34b9fa838b0b385a..e4f1dc1f800f4fcde3ffc6749b4d3dc3e914f596 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -210,3 +210,21 @@ class IdentityMapper(IdentityMapperBase, RecursiveMapper): class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper): pass + + + + +class CSECachingMapperMixin(object): + def map_common_subexpression(self, expr): + try: + ccd = self._cse_cache_dict + except AttributeError: + from weakref import WeakKeyDictionary + ccd = self._cse_cache_dict = WeakKeyDictionary() + + try: + return ccd[expr] + except KeyError: + result = self.map_common_subexpression_uncached(expr) + ccd[expr] = result + return result diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py index 680375f59a69ac1f69e0430a5c6d86283492d3a5..8e6cf0704859c7a34a673046cab8a4c4af873f3e 100644 --- a/pymbolic/mapper/constant_folder.py +++ b/pymbolic/mapper/constant_folder.py @@ -1,4 +1,7 @@ -from pymbolic.mapper import IdentityMapper, NonrecursiveIdentityMapper +from pymbolic.mapper import \ + IdentityMapper, \ + NonrecursiveIdentityMapper, \ + CSECachingMapperMixin @@ -61,20 +64,35 @@ class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): -class ConstantFoldingMapper(ConstantFoldingMapperBase, IdentityMapper): - pass +class ConstantFoldingMapper( + CSECachingMapperMixin, + ConstantFoldingMapperBase, + IdentityMapper): + + map_common_subexpression_uncached = \ + IdentityMapper.map_common_subexpression class NonrecursiveConstantFoldingMapper( + CSECachingMapperMixin, NonrecursiveIdentityMapper, ConstantFoldingMapperBase): - pass -class CommutativeConstantFoldingMapper(CommutativeConstantFoldingMapperBase, + map_common_subexpression_uncached = \ + NonrecursiveIdentityMapper.map_common_subexpression + +class CommutativeConstantFoldingMapper( + CSECachingMapperMixin, + CommutativeConstantFoldingMapperBase, IdentityMapper): - pass + + map_common_subexpression_uncached = \ + IdentityMapper.map_common_subexpression class NonrecursiveCommutativeConstantFoldingMapper( + CSECachingMapperMixin, CommutativeConstantFoldingMapperBase, NonrecursiveIdentityMapper,): - pass + + map_common_subexpression_uncached = \ + NonrecursiveIdentityMapper.map_common_subexpression diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index e299c004848dd7234ef01d95b6387fa05472ccf0..832cf53c68c2d19dbe6922900784020bf943f6e1 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -1,9 +1,9 @@ -from pymbolic.mapper import CombineMapper +from pymbolic.mapper import CombineMapper, CSECachingMapperMixin -class DependencyMapper(CombineMapper): +class DependencyMapper(CSECachingMapperMixin, CombineMapper): def __init__(self, include_subscripts=True, include_lookups=True, @@ -62,7 +62,7 @@ class DependencyMapper(CombineMapper): else: return CombineMapper.map_subscript(self, expr) - def map_common_subexpression(self, expr): + def map_common_subexpression_uncached(self, expr): if self.include_cses: return set([expr]) else: diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 7a654f1d482aeaa2edb9cc07df14c2837897de48..64bd06f1dc2a6bcd3fb1e94b2f55ab34900ce7e7 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -156,7 +156,7 @@ class Expression(object): Subclasses should generally not override this method, but instead provide an implementation of L{is_equal}. """ - if id(self) == id(other): + if self is other: return True elif hash(self) != hash(other): return False