diff --git a/doc/mappers.rst b/doc/mappers.rst index a3c60b6f4c143a9278f5704c641db36305abc170..c41842760ade982ac39a386739577be6a3907133 100644 --- a/doc/mappers.rst +++ b/doc/mappers.rst @@ -55,5 +55,6 @@ Finding expression properties .. automodule:: pymbolic.mapper.flop_counter .. autoclass:: FlopCounter +.. autoclass:: CSEAwareFlopCounter .. vim: sw=4 diff --git a/pymbolic/mapper/flop_counter.py b/pymbolic/mapper/flop_counter.py index 445b04ebc2c9b34d701eb0ba23da9c607febd9bd..fc8c92f3cb6d0a96b78433773e981ffb3b060078 100644 --- a/pymbolic/mapper/flop_counter.py +++ b/pymbolic/mapper/flop_counter.py @@ -22,10 +22,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from pymbolic.mapper import CombineMapper +from pymbolic.mapper import CombineMapper, CachingMapperMixin -class FlopCounter(CombineMapper): +class FlopCounterBase(CombineMapper): def combine(self, values): return sum(values) @@ -55,3 +55,28 @@ class FlopCounter(CombineMapper): return self.rec(expr.criterion) + max( self.rec(expr.then), self.rec(expr.else_)) + + +class FlopCounter(FlopCounterBase, CachingMapperMixin): + def map_common_subexpression_uncached(self, expr): + return self.rec(expr.child) + + +class CSEAwareFlopCounter(FlopCounterBase): + """A flop counter that only counts the contribution from common + subexpressions once. + + .. warning:: + + You must use a fresh mapper for each new evaluation operation for which + reuse may take place. + """ + def __init__(self): + self.cse_seen_set = set() + + def map_common_subexpression(self, expr): + if expr in self.cse_seen_set: + return 0 + else: + self.cse_seen_set.add(expr) + return self.rec(expr.child) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index bb92ac841a8c2be2cce9e4a704561408fd4e2b09..a049cd2d33098a8dff11fbab51c2c8d3ff48c29d 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -560,6 +560,20 @@ def test_latex_mapper(): shutil.rmtree(latex_dir) +def test_flop_counter(): + x = prim.Variable("x") + y = prim.Variable("y") + z = prim.Variable("z") + + subexpr = prim.CommonSubexpression(3 * (x**2 + y + z)) + expr = 3*subexpr + subexpr + + from pymbolic.mapper.flop_counter import FlopCounter, CSEAwareFlopCounter + assert FlopCounter()(expr) == 4 * 2 + 2 + + assert CSEAwareFlopCounter()(expr) == 4 + 2 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: