diff --git a/src/mapper/flop_counter.py b/src/mapper/flop_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..029f7282f576870763dcc40f250361cb54ba73f8 --- /dev/null +++ b/src/mapper/flop_counter.py @@ -0,0 +1,31 @@ +from pymbolic.mapper import CombineMapper + + + + +class FlopCounter(CombineMapper): + def combine(self, values): + return sum(values) + + def handle_unsupported_expression(self, expr, *args, **kwargs): + return 0 + + def map_constant(self, expr): + return 0 + + def map_variable(self, expr): + return 0 + + def map_sum(self, expr): + if expr.children: + return len(expr.children) - 1 + sum(self.rec(ch) for ch in expr.children) + else: + return 0 + + map_product = map_sum + + def map_quotient(self, expr, *args): + return 1 + self.rec(expr.numerator) + self.rec(expr.denominator) + + def map_power(self, expr, *args): + return 1 + self.rec(expr.base) + self.rec(expr.exponent)