From a911877b81a82a59fbc735ad604c01461429d1ab Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 27 Mar 2012 00:42:31 -0400 Subject: [PATCH] Add common subexpression finder. --- pymbolic/cse.py | 103 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 pymbolic/cse.py diff --git a/pymbolic/cse.py b/pymbolic/cse.py new file mode 100644 index 0000000..f4f53f7 --- /dev/null +++ b/pymbolic/cse.py @@ -0,0 +1,103 @@ +from __future__ import division +import pymbolic.primitives as prim +from pymbolic.mapper import IdentityMapper, WalkMapper + +COMMUTATIVE_CLASSES = (prim.Sum, prim.Product) + + + + +def get_normalized_cse_key(node): + if isinstance(node, COMMUTATIVE_CLASSES): + return type(node), frozenset(node.children) + else: + return node + + + + +class CSEMapper(IdentityMapper): + def __init__(self, to_eliminate): + self.to_eliminate = to_eliminate + + self.canonical_subexprs = {} + + def get_cse(self, expr, key=None): + if key is None: + key = get_normalized_cse_key(expr) + + try: + return self.canonical_subexprs[key] + except KeyError: + new_expr = prim.CommonSubexpression( + getattr(IdentityMapper, expr.mapper_method)(self, expr)) + self.canonical_subexprs[key] = new_expr + return new_expr + + def map_sum(self, expr): + key = get_normalized_cse_key(expr) + if key in self.to_eliminate: + result = self.get_cse(expr, key) + return result + else: + return getattr(IdentityMapper, expr.mapper_method)(self, expr) + + map_product = map_sum + map_power = map_sum + map_quotient = map_sum + map_remainder = map_sum + map_floor_div = map_sum + map_call = map_sum + + def map_quotient(self, expr): + if expr in self.to_eliminate: + return self.get_cse(expr) + else: + return IdentityMapper.map_quotient(self, expr) + + def map_substitution(self, expr): + return type(expr)( + expr.child, + expr.variables, + tuple(self.rec(v) for v in expr.values)) + + + + +class UseCountMapper(WalkMapper): + def __init__(self): + self.subexpr_counts = {} + + def visit(self, expr): + key = get_normalized_cse_key(expr) + + if key in self.subexpr_counts: + self.subexpr_counts[key] += 1 + + # do not re-traverse (and thus re-count subexpressions) + return False + else: + self.subexpr_counts[key] = 1 + + # continue traversing + return True + + + + +def tag_common_subexpressions(exprs): + ucm = UseCountMapper() + + if isinstance(exprs, prim.Expression): + raise TypeError("exprs should be an iterable of expressions") + + for expr in exprs: + ucm(expr) + + to_eliminate = set([subexpr_key + for subexpr_key, count in ucm.subexpr_counts.iteritems() + if count > 1]) + cse_mapper = CSEMapper(to_eliminate) + result = [cse_mapper(expr) for expr in exprs] + return result + -- GitLab