diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index d1efd4a44254a6d829ce63b66bb228cea1efab6a..0a030b1088bfcd6cad15146ae583e8a134b5a185 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -26,6 +26,8 @@ THE SOFTWARE. import numpy as np + +from pymbolic.mapper import CSECachingMapperMixin from loopy.tools import intern_frozenset_of_ids from loopy.symbolic import IdentityMapper, WalkMapper from loopy.kernel.data import ( @@ -995,7 +997,7 @@ def parse_domains(domains, defines): # {{{ guess kernel args (if requested) -class IndexRankFinder(WalkMapper): +class IndexRankFinder(CSECachingMapperMixin, WalkMapper): def __init__(self, arg_name): self.arg_name = arg_name self.index_ranks = [] @@ -1012,6 +1014,13 @@ class IndexRankFinder(WalkMapper): else: self.index_ranks.append(len(expr.index)) + def map_common_subexpression_uncached(self, expr): + if not self.visit(expr): + return + + self.rec(expr.child) + self.post_visit(expr) + class ArgumentGuesser: def __init__(self, domains, instructions, temporary_variables, diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b14fba5706b83c94a86b66079925939567d60594..50c891be476810887720c4e13c9659966b431f5d 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -38,6 +38,7 @@ from pymbolic.mapper import ( IdentityMapper as IdentityMapperBase, WalkMapper as WalkMapperBase, CallbackMapper as CallbackMapperBase, + CSECachingMapperMixin, ) from pymbolic.mapper.evaluator import \ EvaluationMapper as EvaluationMapperBase @@ -113,10 +114,14 @@ class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass -class PartialEvaluationMapper(EvaluationMapperBase, IdentityMapperMixin): +class PartialEvaluationMapper( + EvaluationMapperBase, CSECachingMapperMixin, IdentityMapperMixin): def map_variable(self, expr): return expr + def map_common_subexpression_uncached(self, expr): + return type(expr)(self.rec(expr.child), expr.prefix, expr.scope) + class WalkMapper(WalkMapperBase): def map_literal(self, expr, *args): @@ -162,8 +167,10 @@ class CombineMapper(CombineMapperBase): map_linear_subscript = CombineMapperBase.map_subscript -class SubstitutionMapper(SubstitutionMapperBase, IdentityMapperMixin): - pass +class SubstitutionMapper( + CSECachingMapperMixin, SubstitutionMapperBase, IdentityMapperMixin): + def map_common_subexpression_uncached(self, expr): + return type(expr)(self.rec(expr.child), expr.prefix, expr.scope) class ConstantFoldingMapper(ConstantFoldingMapperBase,