Patch summary: add a hashable TypedOp(dtype, name) class to
loopy/statistics.py, use it instead of (dtype, name) tuples as the
ToCountMap keys built by ExpressionOpCounter, and update
test/test_statistics.py to look results up with TypedOp keys.

NOTE(review): the copy this was restored from had every whitespace run
collapsed. Line breaks and blank context lines were placed to satisfy the
hunk line counts; indentation was reconstructed from Python syntax and may
differ from the repository. Verify before applying, or apply with
"git apply --ignore-whitespace".

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 96bf511b59948f4e56b22b078ff2d08c079bb498..5855f08521b79fc4ecea0ba1ba0d74b563d371e3 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -105,6 +105,21 @@ def stringify_stats_mapping(m):
     return result
 
 
+class TypedOp:
+
+    def __init__(self, dtype, name):
+        self.dtype = dtype
+        self.name = name
+
+    def __eq__(self, other):
+        return isinstance(other, TypedOp) and (
+                other.dtype == self.dtype and
+                other.name == self.name )
+
+    def __hash__(self):
+        return hash(str(self.dtype)+self.name)
+
+
 class StridedGmemAccess:
 
     def __init__(self, dtype, stride, direction=None):
@@ -151,7 +166,7 @@ class ExpressionOpCounter(CombineMapper):
 
     def map_call(self, expr):
         return ToCountMap(
-                    {(self.type_inf(expr), 'func:'+str(expr.function)): 1}
+                    {TypedOp(self.type_inf(expr), 'func:'+str(expr.function)): 1}
                     ) + self.rec(expr.parameters)
 
     # def map_call_with_kwargs(self, expr):  # implemented in CombineMapper
@@ -164,20 +179,20 @@ class ExpressionOpCounter(CombineMapper):
     def map_sum(self, expr):
         assert expr.children
         return ToCountMap(
-                    {(self.type_inf(expr), 'add'): len(expr.children)-1}
+                    {TypedOp(self.type_inf(expr), 'add'): len(expr.children)-1}
                     ) + sum(self.rec(child) for child in expr.children)
 
     def map_product(self, expr):
         from pymbolic.primitives import is_zero
         assert expr.children
-        return sum(ToCountMap({(self.type_inf(expr), 'mul'): 1})
+        return sum(ToCountMap({TypedOp(self.type_inf(expr), 'mul'): 1})
                    + self.rec(child)
                    for child in expr.children
                    if not is_zero(child + 1)) + \
-                   ToCountMap({(self.type_inf(expr), 'mul'): -1})
+                   ToCountMap({TypedOp(self.type_inf(expr), 'mul'): -1})
 
     def map_quotient(self, expr, *args):
-        return ToCountMap({(self.type_inf(expr), 'div'): 1}) \
+        return ToCountMap({TypedOp(self.type_inf(expr), 'div'): 1}) \
                                 + self.rec(expr.numerator) \
                                 + self.rec(expr.denominator)
 
@@ -185,24 +200,24 @@ class ExpressionOpCounter(CombineMapper):
     map_remainder = map_quotient
 
     def map_power(self, expr):
-        return ToCountMap({(self.type_inf(expr), 'pow'): 1}) \
+        return ToCountMap({TypedOp(self.type_inf(expr), 'pow'): 1}) \
                                 + self.rec(expr.base) \
                                 + self.rec(expr.exponent)
 
     def map_left_shift(self, expr):
-        return ToCountMap({(self.type_inf(expr), 'shift'): 1}) \
+        return ToCountMap({TypedOp(self.type_inf(expr), 'shift'): 1}) \
                                 + self.rec(expr.shiftee) \
                                 + self.rec(expr.shift)
 
     map_right_shift = map_left_shift
 
     def map_bitwise_not(self, expr):
-        return ToCountMap({(self.type_inf(expr), 'bw'): 1}) \
+        return ToCountMap({TypedOp(self.type_inf(expr), 'bw'): 1}) \
                                 + self.rec(expr.child)
 
     def map_bitwise_or(self, expr):
         return ToCountMap(
-                    {(self.type_inf(expr), 'bw'): len(expr.children)-1}
+                    {TypedOp(self.type_inf(expr), 'bw'): len(expr.children)-1}
                     ) + sum(self.rec(child) for child in expr.children)
 
     map_bitwise_xor = map_bitwise_or
@@ -230,9 +245,9 @@ class ExpressionOpCounter(CombineMapper):
         return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_)
 
     def map_min(self, expr):
-        return ToCountMap(
-                    {(self.type_inf(expr), 'maxmin'): len(expr.children)-1}
-                    ) + sum(self.rec(child) for child in expr.children)
+        return ToCountMap({TypedOp(
+            self.type_inf(expr), 'maxmin'): len(expr.children)-1}
+            ) + sum(self.rec(child) for child in expr.children)
 
     map_max = map_min
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index a4fc022d5697ccf00835d69b9822a66ff5bfc456..5d6fac573aa63f342f3c1d06a7edbf7ca79ae34c 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -32,7 +32,8 @@ from loopy.statistics import (
         get_op_poly,
         get_gmem_access_poly,
         get_barrier_poly,
-        StridedGmemAccess)
+        StridedGmemAccess,
+        TypedOp)
 
 import numpy as np
 
@@ -57,11 +58,11 @@ def test_op_counter_basic():
     m = 256
     l = 128
     params = {'n': n, 'm': m, 'l': l}
-    f32add = poly[(np.dtype(np.float32), 'add')].eval_with_dict(params)
-    f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
-    f32div = poly[(np.dtype(np.float32), 'div')].eval_with_dict(params)
-    f64mul = poly[(np.dtype(np.float64), 'mul')].eval_with_dict(params)
-    i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
+    f32add = poly[TypedOp(np.dtype(np.float32), 'add')].eval_with_dict(params)
+    f32mul = poly[TypedOp(np.dtype(np.float32), 'mul')].eval_with_dict(params)
+    f32div = poly[TypedOp(np.dtype(np.float32), 'div')].eval_with_dict(params)
+    f64mul = poly[TypedOp(np.dtype(np.float64), 'mul')].eval_with_dict(params)
+    i32add = poly[TypedOp(np.dtype(np.int32), 'add')].eval_with_dict(params)
     assert f32add == f32mul == f32div == n*m*l
     assert f64mul == n*m
     assert i32add == n*m*2
@@ -82,8 +83,8 @@ def test_op_counter_reduction():
     m = 256
     l = 128
     params = {'n': n, 'm': m, 'l': l}
-    f32add = poly[(np.dtype(np.float32), 'add')].eval_with_dict(params)
-    f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
+    f32add = poly[TypedOp(np.dtype(np.float32), 'add')].eval_with_dict(params)
+    f32mul = poly[TypedOp(np.dtype(np.float32), 'mul')].eval_with_dict(params)
     assert f32add == f32mul == n*m*l
 
 
@@ -104,10 +105,10 @@ def test_op_counter_logic():
     m = 256
     l = 128
     params = {'n': n, 'm': m, 'l': l}
-    f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
-    f64add = poly[(np.dtype(np.float64), 'add')].eval_with_dict(params)
-    f64div = poly[(np.dtype(np.float64), 'div')].eval_with_dict(params)
-    i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
+    f32mul = poly[TypedOp(np.dtype(np.float32), 'mul')].eval_with_dict(params)
+    f64add = poly[TypedOp(np.dtype(np.float64), 'add')].eval_with_dict(params)
+    f64div = poly[TypedOp(np.dtype(np.float64), 'div')].eval_with_dict(params)
+    i32add = poly[TypedOp(np.dtype(np.int32), 'add')].eval_with_dict(params)
     assert f32mul == n*m
     assert f64div == 2*n*m  # TODO why?
     assert f64add == n*m
@@ -133,18 +134,18 @@ def test_op_counter_specialops():
     m = 256
     l = 128
     params = {'n': n, 'm': m, 'l': l}
-    f32mul = poly[(np.dtype(np.float32), 'mul')].eval_with_dict(params)
-    f32div = poly[(np.dtype(np.float32), 'div')].eval_with_dict(params)
-    f32add = poly[(np.dtype(np.float32), 'add')].eval_with_dict(params)
-    f64pow = poly[(np.dtype(np.float64), 'pow')].eval_with_dict(params)
-    f64add = poly[(np.dtype(np.float64), 'add')].eval_with_dict(params)
-    i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
-    f64rsqrt = poly[(np.dtype(np.float64), 'func:rsqrt')].eval_with_dict(params)
-    f64sin = poly[(np.dtype(np.float64), 'func:sin')].eval_with_dict(params)
+    f32mul = poly[TypedOp(np.dtype(np.float32), 'mul')].eval_with_dict(params)
+    f32div = poly[TypedOp(np.dtype(np.float32), 'div')].eval_with_dict(params)
+    f32add = poly[TypedOp(np.dtype(np.float32), 'add')].eval_with_dict(params)
+    f64pow = poly[TypedOp(np.dtype(np.float64), 'pow')].eval_with_dict(params)
+    f64add = poly[TypedOp(np.dtype(np.float64), 'add')].eval_with_dict(params)
+    i32add = poly[TypedOp(np.dtype(np.int32), 'add')].eval_with_dict(params)
+    f64rsq = poly[TypedOp(np.dtype(np.float64), 'func:rsqrt')].eval_with_dict(params)
+    f64sin = poly[TypedOp(np.dtype(np.float64), 'func:sin')].eval_with_dict(params)
     assert f32div == 2*n*m*l
     assert f32mul == f32add == n*m*l
     assert f64add == 3*n*m
-    assert f64pow == i32add == f64rsqrt == f64sin == n*m
+    assert f64pow == i32add == f64rsq == f64sin == n*m
 
 
 def test_op_counter_bitwise():
@@ -169,12 +170,12 @@ def test_op_counter_bitwise():
     m = 256
     l = 128
     params = {'n': n, 'm': m, 'l': l}
-    i32add = poly[(np.dtype(np.int32), 'add')].eval_with_dict(params)
-    i32bw = poly[(np.dtype(np.int32), 'bw')].eval_with_dict(params)
-    i64bw = poly[(np.dtype(np.int64), 'bw')].eval_with_dict(params)
-    i64mul = poly[(np.dtype(np.int64), 'mul')].eval_with_dict(params)
-    i64add = poly[(np.dtype(np.int64), 'add')].eval_with_dict(params)
-    i64shift = poly[(np.dtype(np.int64), 'shift')].eval_with_dict(params)
+    i32add = poly[TypedOp(np.dtype(np.int32), 'add')].eval_with_dict(params)
+    i32bw = poly[TypedOp(np.dtype(np.int32), 'bw')].eval_with_dict(params)
+    i64bw = poly[TypedOp(np.dtype(np.int64), 'bw')].eval_with_dict(params)
+    i64mul = poly[TypedOp(np.dtype(np.int64), 'mul')].eval_with_dict(params)
+    i64add = poly[TypedOp(np.dtype(np.int64), 'add')].eval_with_dict(params)
+    i64shift = poly[TypedOp(np.dtype(np.int64), 'shift')].eval_with_dict(params)
     assert i32add == n*m+n*m*l
     assert i32bw == 2*n*m*l
     assert i64bw == 2*n*m
@@ -203,7 +204,7 @@ def test_op_counter_triangular_domain():
     else:
         expect_fallback = False
 
-    poly = get_op_poly(knl)[(np.dtype(np.float64), 'mul')]
+    poly = get_op_poly(knl)[TypedOp(np.dtype(np.float64), 'mul')]
     value_dict = dict(m=13, n=200)
     flops = poly.eval_with_dict(value_dict)
 
@@ -555,16 +556,16 @@ def test_all_counters_parallel_matmul():
 
     op_map = get_op_poly(knl)
     f32mul = op_map[
-                        (np.dtype(np.float32), 'mul')
+                        TypedOp(np.dtype(np.float32), 'mul')
                         ].eval_with_dict(params)
     f32add = op_map[
-                        (np.dtype(np.float32), 'add')
+                        TypedOp(np.dtype(np.float32), 'add')
                         ].eval_with_dict(params)
     i32ops = op_map[
-                        (np.dtype(np.int32), 'add')
+                        TypedOp(np.dtype(np.int32), 'add')
                         ].eval_with_dict(params)
     i32ops += op_map[
-                        (np.dtype(np.int32), 'mul')
+                        TypedOp(np.dtype(np.int32), 'mul')
                         ].eval_with_dict(params)
 
     assert f32mul+f32add == n*m*l*2