diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index d1efd4a44254a6d829ce63b66bb228cea1efab6a..0a030b1088bfcd6cad15146ae583e8a134b5a185 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -26,6 +26,8 @@ THE SOFTWARE.
 
 
 import numpy as np
+
+from pymbolic.mapper import CSECachingMapperMixin
 from loopy.tools import intern_frozenset_of_ids
 from loopy.symbolic import IdentityMapper, WalkMapper
 from loopy.kernel.data import (
@@ -995,7 +997,7 @@ def parse_domains(domains, defines):
 
 # {{{ guess kernel args (if requested)
 
-class IndexRankFinder(WalkMapper):
+class IndexRankFinder(CSECachingMapperMixin, WalkMapper):
     def __init__(self, arg_name):
         self.arg_name = arg_name
         self.index_ranks = []
@@ -1012,6 +1014,13 @@ class IndexRankFinder(WalkMapper):
             else:
                 self.index_ranks.append(len(expr.index))
 
+    def map_common_subexpression_uncached(self, expr):
+        # Presumably reached via CSECachingMapperMixin on a cache miss
+        # (TODO confirm); walks the CSE's child like any other node.
+        if self.visit(expr):
+            self.rec(expr.child)
+            self.post_visit(expr)
+
 
 class ArgumentGuesser:
     def __init__(self, domains, instructions, temporary_variables,
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index b14fba5706b83c94a86b66079925939567d60594..50c891be476810887720c4e13c9659966b431f5d 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -38,6 +38,7 @@ from pymbolic.mapper import (
         IdentityMapper as IdentityMapperBase,
         WalkMapper as WalkMapperBase,
         CallbackMapper as CallbackMapperBase,
+        CSECachingMapperMixin,
         )
 from pymbolic.mapper.evaluator import \
         EvaluationMapper as EvaluationMapperBase
@@ -113,10 +114,14 @@ class IdentityMapper(IdentityMapperBase, IdentityMapperMixin):
     pass
 
 
-class PartialEvaluationMapper(EvaluationMapperBase, IdentityMapperMixin):
+class PartialEvaluationMapper(  # NOTE(review): mixin not first in MRO -- confirm
+        EvaluationMapperBase, CSECachingMapperMixin, IdentityMapperMixin):
     def map_variable(self, expr):
         return expr
 
+    def map_common_subexpression_uncached(self, expr):  # rebuild CSE, mapped child
+        return type(expr)(self.rec(expr.child), expr.prefix, expr.scope)
+
 
 class WalkMapper(WalkMapperBase):
     def map_literal(self, expr, *args):
@@ -162,8 +167,10 @@ class CombineMapper(CombineMapperBase):
     map_linear_subscript = CombineMapperBase.map_subscript
 
 
-class SubstitutionMapper(SubstitutionMapperBase, IdentityMapperMixin):
-    pass
+class SubstitutionMapper(  # CSE mixin listed first, ahead of the base mapper
+        CSECachingMapperMixin, SubstitutionMapperBase, IdentityMapperMixin):
+    def map_common_subexpression_uncached(self, expr):  # rebuild CSE, mapped child
+        return type(expr)(self.rec(expr.child), expr.prefix, expr.scope)
 
 
 class ConstantFoldingMapper(ConstantFoldingMapperBase,