diff --git a/pymbolic/cse.py b/pymbolic/cse.py
index 9329a696c12a97975318d10e1631ab11e8dbeee2..120af263eb3c5614903c2600e5c4725b7f1cbe80 100644
--- a/pymbolic/cse.py
+++ b/pymbolic/cse.py
@@ -1,6 +1,6 @@
 from __future__ import division
 import pymbolic.primitives as prim
-from pymbolic.mapper import IdentityMapper, WalkMapper
+from pymbolic.mapper import IdentityMapper, WalkMapper, CSECachingMapperMixin
 from pytools import memoize_method
 
 COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)
@@ -33,7 +33,44 @@ class NormalizedKeyGetter(object):
 
 
 
-class CSEMapper(IdentityMapper):
+class UseCountMapper(WalkMapper):
+    def __init__(self, get_key):
+        self.subexpr_counts = {}
+        self.get_key = get_key
+
+    def visit(self, expr):
+        key = self.get_key(expr)
+
+        if key in self.subexpr_counts:
+            self.subexpr_counts[key] += 1
+
+            # do not re-traverse (and thus re-count) subexpressions
+            return False
+        else:
+            self.subexpr_counts[key] = 1
+
+            # continue traversing
+            return True
+
+    def map_common_subexpression(self, expr, *args, **kwargs):
+        # For existing CSEs, reuse has already been decided.
+        # Only add to the use count; the child is traversed on first sight.
+
+        key = self.get_key(expr)
+        if key in self.subexpr_counts:
+            self.subexpr_counts[key] += 1
+        else:
+            # This order reversal matters: since get_key removes the outer
+            # CSE, we need to traverse first, then add to the counter.
+
+            self.rec(expr.child)
+            self.subexpr_counts[key] = 1
+
+
+
+
+
+class CSEMapper(IdentityMapper, CSECachingMapperMixin):
     def __init__(self, to_eliminate, get_key):
         self.to_eliminate = to_eliminate
         self.get_key = get_key
@@ -67,11 +104,12 @@ class CSEMapper(IdentityMapper):
     map_floor_div = map_sum
     map_call = map_sum
 
-    def map_common_subexpression(self, expr):
+    def map_common_subexpression_uncached(self, expr):
         # Avoid creating CSE(CSE(...))
         #
         # NOTE: This is not equivalent to isinstance--it's more specific,
-        # and for a reason!
+        # and for a reason! (Because we don't want to obliterate derived
+        # CSE types.)
         if type(expr) is prim.CommonSubexpression:
             return prim.wrap_in_cse(self.rec(expr.child), expr.prefix)
         else:
@@ -86,28 +124,6 @@ class CSEMapper(IdentityMapper):
 
 
 
-class UseCountMapper(WalkMapper):
-    def __init__(self, get_key):
-        self.subexpr_counts = {}
-        self.get_key = get_key
-
-    def visit(self, expr):
-        key = self.get_key(expr)
-
-        if key in self.subexpr_counts:
-            self.subexpr_counts[key] += 1
-
-            # do not re-traverse (and thus re-count subexpressions)
-            return False
-        else:
-            self.subexpr_counts[key] = 1
-
-            # continue traversing
-            return True
-
-
-
-
 def tag_common_subexpressions(exprs):
     get_key = NormalizedKeyGetter()
     ucm = UseCountMapper(get_key)
@@ -121,6 +137,7 @@ def tag_common_subexpressions(exprs):
     to_eliminate = set([subexpr_key
         for subexpr_key, count in ucm.subexpr_counts.iteritems()
         if count > 1])
+
     cse_mapper = CSEMapper(to_eliminate, get_key)
     result = [cse_mapper(expr) for expr in exprs]
     return result