Skip to content
Snippets Groups Projects
Commit ca628341 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Various CSE fixes.

parent a911877b
No related branches found
No related tags found
No related merge requests found
from __future__ import division from __future__ import division
import pymbolic.primitives as prim import pymbolic.primitives as prim
from pymbolic.mapper import IdentityMapper, WalkMapper from pymbolic.mapper import IdentityMapper, WalkMapper
from pytools import memoize_method
COMMUTATIVE_CLASSES = (prim.Sum, prim.Product) COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)
def get_normalized_cse_key(node): class CSERemover(IdentityMapper):
if isinstance(node, COMMUTATIVE_CLASSES): def map_common_subexpression(self, expr):
return type(node), frozenset(node.children) return self.rec(expr.child)
else:
return node
class NormalizedKeyGetter(object):
def __init__(self):
self.cse_remover = CSERemover()
@memoize_method
def remove_cses(self, expr):
return self.cse_remover(expr)
def __call__(self, expr):
expr = self.remove_cses(expr)
if isinstance(expr, COMMUTATIVE_CLASSES):
return type(expr), frozenset(expr.children)
else:
return expr
class CSEMapper(IdentityMapper): class CSEMapper(IdentityMapper):
def __init__(self, to_eliminate): def __init__(self, to_eliminate, get_key):
self.to_eliminate = to_eliminate self.to_eliminate = to_eliminate
self.get_key = get_key
self.canonical_subexprs = {} self.canonical_subexprs = {}
def get_cse(self, expr, key=None): def get_cse(self, expr, key=None):
if key is None: if key is None:
key = get_normalized_cse_key(expr) key = self.get_key(expr)
try: try:
return self.canonical_subexprs[key] return self.canonical_subexprs[key]
except KeyError: except KeyError:
new_expr = prim.CommonSubexpression( new_expr = prim.wrap_in_cse(
getattr(IdentityMapper, expr.mapper_method)(self, expr)) getattr(IdentityMapper, expr.mapper_method)(self, expr))
self.canonical_subexprs[key] = new_expr self.canonical_subexprs[key] = new_expr
return new_expr return new_expr
def map_sum(self, expr): def map_sum(self, expr):
key = get_normalized_cse_key(expr) key = self.get_key(expr)
if key in self.to_eliminate: if key in self.to_eliminate:
result = self.get_cse(expr, key) result = self.get_cse(expr, key)
return result return result
...@@ -49,11 +67,9 @@ class CSEMapper(IdentityMapper): ...@@ -49,11 +67,9 @@ class CSEMapper(IdentityMapper):
map_floor_div = map_sum map_floor_div = map_sum
map_call = map_sum map_call = map_sum
def map_quotient(self, expr): def map_common_subexpression(self, expr):
if expr in self.to_eliminate: # don't duplicate CSEs
return self.get_cse(expr) return prim.wrap_in_cse(self.rec(expr.child), expr.prefix)
else:
return IdentityMapper.map_quotient(self, expr)
def map_substitution(self, expr): def map_substitution(self, expr):
return type(expr)( return type(expr)(
...@@ -65,11 +81,12 @@ class CSEMapper(IdentityMapper): ...@@ -65,11 +81,12 @@ class CSEMapper(IdentityMapper):
class UseCountMapper(WalkMapper): class UseCountMapper(WalkMapper):
def __init__(self): def __init__(self, get_key):
self.subexpr_counts = {} self.subexpr_counts = {}
self.get_key = get_key
def visit(self, expr): def visit(self, expr):
key = get_normalized_cse_key(expr) key = self.get_key(expr)
if key in self.subexpr_counts: if key in self.subexpr_counts:
self.subexpr_counts[key] += 1 self.subexpr_counts[key] += 1
...@@ -86,7 +103,8 @@ class UseCountMapper(WalkMapper): ...@@ -86,7 +103,8 @@ class UseCountMapper(WalkMapper):
def tag_common_subexpressions(exprs): def tag_common_subexpressions(exprs):
ucm = UseCountMapper() get_key = NormalizedKeyGetter()
ucm = UseCountMapper(get_key)
if isinstance(exprs, prim.Expression): if isinstance(exprs, prim.Expression):
raise TypeError("exprs should be an iterable of expressions") raise TypeError("exprs should be an iterable of expressions")
...@@ -97,7 +115,7 @@ def tag_common_subexpressions(exprs): ...@@ -97,7 +115,7 @@ def tag_common_subexpressions(exprs):
to_eliminate = set([subexpr_key to_eliminate = set([subexpr_key
for subexpr_key, count in ucm.subexpr_counts.iteritems() for subexpr_key, count in ucm.subexpr_counts.iteritems()
if count > 1]) if count > 1])
cse_mapper = CSEMapper(to_eliminate) cse_mapper = CSEMapper(to_eliminate, get_key)
result = [cse_mapper(expr) for expr in exprs] result = [cse_mapper(expr) for expr in exprs]
return result return result
...@@ -868,6 +868,26 @@ def is_zero(value): ...@@ -868,6 +868,26 @@ def is_zero(value):
def wrap_in_cse(expr, prefix=None):
if isinstance(expr, Variable):
return expr
if isinstance(expr, CommonSubexpression):
if prefix is None:
return expr
if expr.prefix is None:
return CommonSubexpression(expr.child, prefix)
# existing prefix wins
return expr
else:
return CommonSubexpression(expr, prefix)
def make_common_subexpression(field, prefix=None): def make_common_subexpression(field, prefix=None):
try: try:
from pytools.obj_array import log_shape from pytools.obj_array import log_shape
......
...@@ -53,8 +53,6 @@ class ToPymbolicMapper(_SympyMapper): ...@@ -53,8 +53,6 @@ class ToPymbolicMapper(_SympyMapper):
if prim.is_zero(denom-1): if prim.is_zero(denom-1):
return num return num
if isinstance(num, int) and isinstance(denom, int):
return int(num) / int(denom)
return prim.Quotient(num, denom) return prim.Quotient(num, denom)
def map_Pow(self, expr): def map_Pow(self, expr):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment