From 0a64c216b85daccaf386f73459e239a06f9e134e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 12 Dec 2016 16:50:44 -0600
Subject: [PATCH] Use CSE caching for mappers that may get called before the
 CSE -> assignment translation

---
 loopy/kernel/creation.py | 11 ++++++++++-
 loopy/symbolic.py        | 13 ++++++++++---
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index d1efd4a44..0a030b108 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -26,6 +26,8 @@ THE SOFTWARE.
 
 
 import numpy as np
+
+from pymbolic.mapper import CSECachingMapperMixin
 from loopy.tools import intern_frozenset_of_ids
 from loopy.symbolic import IdentityMapper, WalkMapper
 from loopy.kernel.data import (
@@ -995,7 +997,7 @@ def parse_domains(domains, defines):
 
 # {{{ guess kernel args (if requested)
 
-class IndexRankFinder(WalkMapper):
+class IndexRankFinder(CSECachingMapperMixin, WalkMapper):
     def __init__(self, arg_name):
         self.arg_name = arg_name
         self.index_ranks = []
@@ -1012,6 +1014,13 @@ class IndexRankFinder(WalkMapper):
             else:
                 self.index_ranks.append(len(expr.index))
 
+    def map_common_subexpression_uncached(self, expr):
+        if not self.visit(expr):
+            return
+
+        self.rec(expr.child)
+        self.post_visit(expr)
+
 
 class ArgumentGuesser:
     def __init__(self, domains, instructions, temporary_variables,
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index b14fba570..50c891be4 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -38,6 +38,7 @@ from pymbolic.mapper import (
         IdentityMapper as IdentityMapperBase,
         WalkMapper as WalkMapperBase,
         CallbackMapper as CallbackMapperBase,
+        CSECachingMapperMixin,
         )
 from pymbolic.mapper.evaluator import \
         EvaluationMapper as EvaluationMapperBase
@@ -113,10 +114,14 @@ class IdentityMapper(IdentityMapperBase, IdentityMapperMixin):
     pass
 
 
-class PartialEvaluationMapper(EvaluationMapperBase, IdentityMapperMixin):
+class PartialEvaluationMapper(
+        EvaluationMapperBase, CSECachingMapperMixin, IdentityMapperMixin):
     def map_variable(self, expr):
         return expr
 
+    def map_common_subexpression_uncached(self, expr):
+        return type(expr)(self.rec(expr.child), expr.prefix, expr.scope)
+
 
 class WalkMapper(WalkMapperBase):
     def map_literal(self, expr, *args):
@@ -162,8 +167,10 @@ class CombineMapper(CombineMapperBase):
     map_linear_subscript = CombineMapperBase.map_subscript
 
 
-class SubstitutionMapper(SubstitutionMapperBase, IdentityMapperMixin):
-    pass
+class SubstitutionMapper(
+        CSECachingMapperMixin, SubstitutionMapperBase, IdentityMapperMixin):
+    def map_common_subexpression_uncached(self, expr):
+        return type(expr)(self.rec(expr.child), expr.prefix, expr.scope)
 
 
 class ConstantFoldingMapper(ConstantFoldingMapperBase,
-- 
GitLab