From bd8a3b56e965cadac36f2aff2216da68ebaf2708 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 27 Aug 2018 21:09:18 -0500 Subject: [PATCH] Fix interaction between CSEs and flop counts, add CSEAwareFlopCounter --- doc/mappers.rst | 1 + pymbolic/mapper/flop_counter.py | 29 +++++++++++++++++++++++++++-- test/test_pymbolic.py | 14 ++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/doc/mappers.rst b/doc/mappers.rst index a3c60b6..c418427 100644 --- a/doc/mappers.rst +++ b/doc/mappers.rst @@ -55,5 +55,6 @@ Finding expression properties .. automodule:: pymbolic.mapper.flop_counter .. autoclass:: FlopCounter +.. autoclass:: CSEAwareFlopCounter .. vim: sw=4 diff --git a/pymbolic/mapper/flop_counter.py b/pymbolic/mapper/flop_counter.py index 445b04e..fc8c92f 100644 --- a/pymbolic/mapper/flop_counter.py +++ b/pymbolic/mapper/flop_counter.py @@ -22,10 +22,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from pymbolic.mapper import CombineMapper +from pymbolic.mapper import CombineMapper, CachingMapperMixin -class FlopCounter(CombineMapper): +class FlopCounterBase(CombineMapper): def combine(self, values): return sum(values) @@ -55,3 +55,28 @@ class FlopCounter(CombineMapper): return self.rec(expr.criterion) + max( self.rec(expr.then), self.rec(expr.else_)) + + +class FlopCounter(FlopCounterBase, CachingMapperMixin): + def map_common_subexpression_uncached(self, expr): + return self.rec(expr.child) + + +class CSEAwareFlopCounter(FlopCounterBase): + """A flop counter that only counts the contribution from common + subexpressions once. + + .. warning:: + + You must use a fresh mapper for each new evaluation operation for which + reuse may take place. + """ + def __init__(self): + self.cse_seen_set = set() + + def map_common_subexpression(self, expr): + if expr in self.cse_seen_set: + return 0 + else: + self.cse_seen_set.add(expr) + return self.rec(expr.child) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index bb92ac8..a049cd2 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -560,6 +560,20 @@ def test_latex_mapper(): shutil.rmtree(latex_dir) +def test_flop_counter(): + x = prim.Variable("x") + y = prim.Variable("y") + z = prim.Variable("z") + + subexpr = prim.CommonSubexpression(3 * (x**2 + y + z)) + expr = 3*subexpr + subexpr + + from pymbolic.mapper.flop_counter import FlopCounter, CSEAwareFlopCounter + assert FlopCounter()(expr) == 4 * 2 + 2 + + assert CSEAwareFlopCounter()(expr) == 4 + 2 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab