diff --git a/pymbolic/cse.py b/pymbolic/cse.py index 9329a696c12a97975318d10e1631ab11e8dbeee2..120af263eb3c5614903c2600e5c4725b7f1cbe80 100644 --- a/pymbolic/cse.py +++ b/pymbolic/cse.py @@ -1,6 +1,6 @@ from __future__ import division import pymbolic.primitives as prim -from pymbolic.mapper import IdentityMapper, WalkMapper +from pymbolic.mapper import IdentityMapper, WalkMapper, CSECachingMapperMixin from pytools import memoize_method COMMUTATIVE_CLASSES = (prim.Sum, prim.Product) @@ -33,7 +33,44 @@ class NormalizedKeyGetter(object): -class CSEMapper(IdentityMapper): +class UseCountMapper(WalkMapper): + def __init__(self, get_key): + self.subexpr_counts = {} + self.get_key = get_key + + def visit(self, expr): + key = self.get_key(expr) + + if key in self.subexpr_counts: + self.subexpr_counts[key] += 1 + + # do not re-traverse (and thus re-count subexpressions) + return False + else: + self.subexpr_counts[key] = 1 + + # continue traversing + return True + + def map_common_subexpression(self, expr, *args, **kwargs): + # For existing CSEs, reuse has already been decided. + # Add to + + key = self.get_key(expr) + if key in self.subexpr_counts: + self.subexpr_counts[key] += 1 + else: + # This order reversal matters: Since get_key removes the outer + # CSE, need to traverse first, then add to counter. + + self.rec(expr.child) + self.subexpr_counts[key] = 1 + + + + + +class CSEMapper(IdentityMapper, CSECachingMapperMixin): def __init__(self, to_eliminate, get_key): self.to_eliminate = to_eliminate self.get_key = get_key @@ -67,11 +104,12 @@ class CSEMapper(IdentityMapper): map_floor_div = map_sum map_call = map_sum - def map_common_subexpression(self, expr): + def map_common_subexpression_uncached(self, expr): # Avoid creating CSE(CSE(...)) # # NOTE: This is not equivalent to isinstance--it's more specific, - # and for a reason! + # and for a reason! (Because we don't want to obliterate derived + # CSE types.) if type(expr) is prim.CommonSubexpression: return prim.wrap_in_cse(self.rec(expr.child), expr.prefix) else: @@ -86,28 +124,6 @@ class CSEMapper(IdentityMapper): -class UseCountMapper(WalkMapper): - def __init__(self, get_key): - self.subexpr_counts = {} - self.get_key = get_key - - def visit(self, expr): - key = self.get_key(expr) - - if key in self.subexpr_counts: - self.subexpr_counts[key] += 1 - - # do not re-traverse (and thus re-count subexpressions) - return False - else: - self.subexpr_counts[key] = 1 - - # continue traversing - return True - - - - def tag_common_subexpressions(exprs): get_key = NormalizedKeyGetter() ucm = UseCountMapper(get_key) @@ -121,6 +137,7 @@ def tag_common_subexpressions(exprs): to_eliminate = set([subexpr_key for subexpr_key, count in ucm.subexpr_counts.iteritems() if count > 1]) + cse_mapper = CSEMapper(to_eliminate, get_key) result = [cse_mapper(expr) for expr in exprs] return result