From e4f1172a56a9b44ef787b2e3b422e2d5f4d02146 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 1 Jul 2009 14:05:20 -0400
Subject: [PATCH] Add CSECachingMapperMixin. Use it to reduce complexity of
 several mappers.

---
 pymbolic/mapper/__init__.py        | 18 +++++++++++++++++
 pymbolic/mapper/constant_folder.py | 32 +++++++++++++++++++++++-------
 pymbolic/mapper/dependency.py      |  6 +++---
 pymbolic/primitives.py             |  2 +-
 4 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py
index c549938..e4f1dc1 100644
--- a/pymbolic/mapper/__init__.py
+++ b/pymbolic/mapper/__init__.py
@@ -210,3 +210,21 @@ class IdentityMapper(IdentityMapperBase, RecursiveMapper):
 
 class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper):
     pass
+
+
+
+
+class CSECachingMapperMixin(object):
+    def map_common_subexpression(self, expr):
+        try:
+            ccd = self._cse_cache_dict
+        except AttributeError:
+            from weakref import WeakKeyDictionary
+            ccd = self._cse_cache_dict = WeakKeyDictionary()
+
+        try:
+            return ccd[expr]
+        except KeyError:
+            result = self.map_common_subexpression_uncached(expr)
+            ccd[expr] = result
+            return result
diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py
index 680375f..8e6cf07 100644
--- a/pymbolic/mapper/constant_folder.py
+++ b/pymbolic/mapper/constant_folder.py
@@ -1,4 +1,7 @@
-from pymbolic.mapper import IdentityMapper, NonrecursiveIdentityMapper
+from pymbolic.mapper import \
+        IdentityMapper, \
+        NonrecursiveIdentityMapper, \
+        CSECachingMapperMixin
 
 
 
@@ -61,20 +64,35 @@ class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase):
 
 
 
-class ConstantFoldingMapper(ConstantFoldingMapperBase, IdentityMapper):
-    pass
+class ConstantFoldingMapper(
+        CSECachingMapperMixin, 
+        ConstantFoldingMapperBase, 
+        IdentityMapper):
+
+    map_common_subexpression_uncached = \
+            IdentityMapper.map_common_subexpression
 
 class NonrecursiveConstantFoldingMapper(
+        CSECachingMapperMixin,
         NonrecursiveIdentityMapper, 
         ConstantFoldingMapperBase):
-    pass
 
-class CommutativeConstantFoldingMapper(CommutativeConstantFoldingMapperBase,
+    map_common_subexpression_uncached = \
+            NonrecursiveIdentityMapper.map_common_subexpression
+
+class CommutativeConstantFoldingMapper(
+        CSECachingMapperMixin,
+        CommutativeConstantFoldingMapperBase,
         IdentityMapper):
-    pass
+
+    map_common_subexpression_uncached = \
+            IdentityMapper.map_common_subexpression
 
 class NonrecursiveCommutativeConstantFoldingMapper(
+        CSECachingMapperMixin,
         CommutativeConstantFoldingMapperBase,
         NonrecursiveIdentityMapper,):
-    pass
+
+    map_common_subexpression_uncached = \
+            NonrecursiveIdentityMapper.map_common_subexpression
 
diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py
index e299c00..832cf53 100644
--- a/pymbolic/mapper/dependency.py
+++ b/pymbolic/mapper/dependency.py
@@ -1,9 +1,9 @@
-from pymbolic.mapper import CombineMapper
+from pymbolic.mapper import CombineMapper, CSECachingMapperMixin
 
 
 
 
-class DependencyMapper(CombineMapper):
+class DependencyMapper(CSECachingMapperMixin, CombineMapper):
     def __init__(self, 
             include_subscripts=True, 
             include_lookups=True,
@@ -62,7 +62,7 @@ class DependencyMapper(CombineMapper):
         else:
             return CombineMapper.map_subscript(self, expr)
 
-    def map_common_subexpression(self, expr):
+    def map_common_subexpression_uncached(self, expr):
         if self.include_cses:
             return set([expr])
         else:
diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index 7a654f1..64bd06f 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -156,7 +156,7 @@ class Expression(object):
         Subclasses should generally not override this method, but instead 
         provide an implementation of L{is_equal}.
         """
-        if id(self) == id(other):
+        if self is other:
             return True
         elif hash(self) != hash(other):
             return False
-- 
GitLab