From 0a64c216b85daccaf386f73459e239a06f9e134e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 12 Dec 2016 16:50:44 -0600 Subject: [PATCH] Use CSE caching for mappers that may get called before the CSE -> assignment translation --- loopy/kernel/creation.py | 11 ++++++++++- loopy/symbolic.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index d1efd4a44..0a030b108 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -26,6 +26,8 @@ THE SOFTWARE. import numpy as np + +from pymbolic.mapper import CSECachingMapperMixin from loopy.tools import intern_frozenset_of_ids from loopy.symbolic import IdentityMapper, WalkMapper from loopy.kernel.data import ( @@ -995,7 +997,7 @@ def parse_domains(domains, defines): # {{{ guess kernel args (if requested) -class IndexRankFinder(WalkMapper): +class IndexRankFinder(CSECachingMapperMixin, WalkMapper): def __init__(self, arg_name): self.arg_name = arg_name self.index_ranks = [] @@ -1012,6 +1014,13 @@ class IndexRankFinder(WalkMapper): else: self.index_ranks.append(len(expr.index)) + def map_common_subexpression_uncached(self, expr): + if not self.visit(expr): + return + + self.rec(expr.child) + self.post_visit(expr) + class ArgumentGuesser: def __init__(self, domains, instructions, temporary_variables, diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b14fba570..50c891be4 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -38,6 +38,7 @@ from pymbolic.mapper import ( IdentityMapper as IdentityMapperBase, WalkMapper as WalkMapperBase, CallbackMapper as CallbackMapperBase, + CSECachingMapperMixin, ) from pymbolic.mapper.evaluator import \ EvaluationMapper as EvaluationMapperBase @@ -113,10 +114,14 @@ class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass -class PartialEvaluationMapper(EvaluationMapperBase, IdentityMapperMixin): +class PartialEvaluationMapper( + EvaluationMapperBase, CSECachingMapperMixin, IdentityMapperMixin): def map_variable(self, expr): return expr + def map_common_subexpression_uncached(self, expr): + return type(expr)(self.rec(expr.child), expr.prefix, expr.scope) + class WalkMapper(WalkMapperBase): def map_literal(self, expr, *args): @@ -162,8 +167,10 @@ class CombineMapper(CombineMapperBase): map_linear_subscript = CombineMapperBase.map_subscript -class SubstitutionMapper(SubstitutionMapperBase, IdentityMapperMixin): - pass +class SubstitutionMapper( + CSECachingMapperMixin, SubstitutionMapperBase, IdentityMapperMixin): + def map_common_subexpression_uncached(self, expr): + return type(expr)(self.rec(expr.child), expr.prefix, expr.scope) class ConstantFoldingMapper(ConstantFoldingMapperBase, -- GitLab