diff --git a/pymbolic/cse.py b/pymbolic/cse.py index f4f53f743386ab5242183a0468228927df674bf3..7ebaae61ef7735d3d029474c09f0f3e672d60581 100644 --- a/pymbolic/cse.py +++ b/pymbolic/cse.py @@ -1,41 +1,59 @@ from __future__ import division import pymbolic.primitives as prim from pymbolic.mapper import IdentityMapper, WalkMapper +from pytools import memoize_method COMMUTATIVE_CLASSES = (prim.Sum, prim.Product) -def get_normalized_cse_key(node): - if isinstance(node, COMMUTATIVE_CLASSES): - return type(node), frozenset(node.children) - else: - return node +class CSERemover(IdentityMapper): + def map_common_subexpression(self, expr): + return self.rec(expr.child) + + + + +class NormalizedKeyGetter(object): + def __init__(self): + self.cse_remover = CSERemover() + + @memoize_method + def remove_cses(self, expr): + return self.cse_remover(expr) + + def __call__(self, expr): + expr = self.remove_cses(expr) + if isinstance(expr, COMMUTATIVE_CLASSES): + return type(expr), frozenset(expr.children) + else: + return expr class CSEMapper(IdentityMapper): - def __init__(self, to_eliminate): + def __init__(self, to_eliminate, get_key): self.to_eliminate = to_eliminate + self.get_key = get_key self.canonical_subexprs = {} def get_cse(self, expr, key=None): if key is None: - key = get_normalized_cse_key(expr) + key = self.get_key(expr) try: return self.canonical_subexprs[key] except KeyError: - new_expr = prim.CommonSubexpression( + new_expr = prim.wrap_in_cse( getattr(IdentityMapper, expr.mapper_method)(self, expr)) self.canonical_subexprs[key] = new_expr return new_expr def map_sum(self, expr): - key = get_normalized_cse_key(expr) + key = self.get_key(expr) if key in self.to_eliminate: result = self.get_cse(expr, key) return result @@ -49,11 +67,9 @@ class CSEMapper(IdentityMapper): map_floor_div = map_sum map_call = map_sum - def map_quotient(self, expr): - if expr in self.to_eliminate: - return self.get_cse(expr) - else: - return IdentityMapper.map_quotient(self, expr) + def map_common_subexpression(self, expr): + # don't duplicate CSEs + return prim.wrap_in_cse(self.rec(expr.child), expr.prefix) def map_substitution(self, expr): return type(expr)( @@ -65,11 +81,12 @@ class CSEMapper(IdentityMapper): class UseCountMapper(WalkMapper): - def __init__(self): + def __init__(self, get_key): self.subexpr_counts = {} + self.get_key = get_key def visit(self, expr): - key = get_normalized_cse_key(expr) + key = self.get_key(expr) if key in self.subexpr_counts: self.subexpr_counts[key] += 1 @@ -86,7 +103,8 @@ class UseCountMapper(WalkMapper): def tag_common_subexpressions(exprs): - ucm = UseCountMapper() + get_key = NormalizedKeyGetter() + ucm = UseCountMapper(get_key) if isinstance(exprs, prim.Expression): raise TypeError("exprs should be an iterable of expressions") @@ -97,7 +115,7 @@ def tag_common_subexpressions(exprs): to_eliminate = set([subexpr_key for subexpr_key, count in ucm.subexpr_counts.iteritems() if count > 1]) - cse_mapper = CSEMapper(to_eliminate) + cse_mapper = CSEMapper(to_eliminate, get_key) result = [cse_mapper(expr) for expr in exprs] return result diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 6674e351cdcb664a61d0c331158bc9f19520e5e5..dcfa09f5900d9447c28c7d6e9e7f7862b1a59aaa 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -868,6 +868,26 @@ def is_zero(value): +def wrap_in_cse(expr, prefix=None): + if isinstance(expr, Variable): + return expr + + if isinstance(expr, CommonSubexpression): + if prefix is None: + return expr + if expr.prefix is None: + return CommonSubexpression(expr.child, prefix) + + # existing prefix wins + return expr + + else: + return CommonSubexpression(expr, prefix) + + + + + def make_common_subexpression(field, prefix=None): try: from pytools.obj_array import log_shape diff --git a/pymbolic/sympy_conv.py b/pymbolic/sympy_conv.py index f656676cd70345f88ba4f2de89d50e7e8a2ef69c..049131c9da35e81f9250c60419cfe9d264f24c6e 100644 --- a/pymbolic/sympy_conv.py +++ b/pymbolic/sympy_conv.py @@ -53,8 +53,6 @@ class ToPymbolicMapper(_SympyMapper): if prim.is_zero(denom-1): return num - if isinstance(num, int) and isinstance(denom, int): - return int(num) / int(denom) return prim.Quotient(num, denom) def map_Pow(self, expr):