diff --git a/loopy/statistics.py b/loopy/statistics.py index eff571668775ff923443ea93af39d02f8795e115..96bf511b59948f4e56b22b078ff2d08c079bb498 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -105,16 +105,24 @@ def stringify_stats_mapping(m): return result -class DataAccess: +class StridedGmemAccess: - def __init__(self, stride=0): + def __init__(self, dtype, stride, direction=None): + self.dtype = dtype self.stride = stride + self.direction = direction def __eq__(self, other): - return isinstance(other, DataAccess) and other.stride == self.stride #TODO is this okay? + return isinstance(other, StridedGmemAccess) and ( + other.dtype == self.dtype and + other.stride == self.stride and + other.direction == self.direction ) def __hash__(self): - return hash(self.stride) + if self.direction == None: + return hash(str(self.dtype)+str(self.stride)+"None") + else: + return hash(str(self.dtype)+str(self.stride)+self.direction) # {{{ ExpressionOpCounter @@ -292,7 +300,6 @@ class GlobalSubscriptCounter(CombineMapper): # find min tag axis import sys - #local_id0 = None min_tag_axis = sys.maxsize local_id_found = False for iname in my_inames: @@ -301,14 +308,11 @@ class GlobalSubscriptCounter(CombineMapper): local_id_found = True if tag.axis < min_tag_axis: min_tag_axis = tag.axis - #if tag.axis == 0: - # local_id0 = iname - # break if not local_id_found: # count as uniform access return ToCountMap( - {(self.type_inf(expr), DataAccess(stride=0)): 1} + {StridedGmemAccess(self.type_inf(expr), 0): 1} ) + self.rec(expr.index) # get local_id associated with minimum tag axis @@ -326,41 +330,27 @@ class GlobalSubscriptCounter(CombineMapper): # check coefficient of local_id0 for each axis from loopy.symbolic import CoefficientCollector from pymbolic.primitives import Variable - #print("==========================================================================================") - #print("expr: ", expr) - #print("min_lid: ", min_lid) - #print("min_tag_axis: ", min_tag_axis) - #print("Var(min_lid): ", Variable(min_lid)) for idx, axis_tag in zip(index, array.dim_tags): - #print("...........................................................................................") - #print("idx, axis_tag: ", idx, "\t", axis_tag) coeffs = CoefficientCollector()(idx) - #print("coeffs: ", coeffs) # check if he contains the min lid guy try: coeff_min_lid = coeffs[Variable(min_lid)] except KeyError: # does not contain min_lid - #print("key error") continue - #print("coeff_min_lid: ", coeff_min_lid) - #print("axis_tag: ", axis_tag) # found coefficient of min_lid # now determine stride from loopy.kernel.array import FixedStrideArrayDimTag if isinstance(axis_tag, FixedStrideArrayDimTag): stride = axis_tag.stride else: - #print("continuing") continue - #print("stride: ", stride) total_stride = stride*coeff_min_lid #TODO is there a case where this^ does not execute, or executes more than once for two different axes? - return ToCountMap({(self.type_inf(expr), - DataAccess(stride=total_stride)): 1} - ) + self.rec(expr.index) + return ToCountMap({StridedGmemAccess(self.type_inf(expr), + total_stride): 1}) + self.rec(expr.index) def map_sum(self, expr): if expr.children: @@ -727,26 +717,28 @@ def get_gmem_access_poly(knl): # for now just counting subscripts for insn in knl.instructions: # count subscripts, distinguishing loads and stores subs_expr = subscript_counter(insn.expression) - subs_expr = ToCountMap(dict( - (key + ("load",), val) - for key, val in six.iteritems(subs_expr.dict))) + for key in subs_expr.dict: + subs_expr.dict[StridedGmemAccess( + key.dtype, key.stride, 'load') + ] = subs_expr.dict.pop(key) subs_assignee = subscript_counter(insn.assignee) - subs_assignee = ToCountMap(dict( - (key + ("store",), val) - for key, val in six.iteritems(subs_assignee.dict))) + for key in subs_assignee.dict: + subs_assignee.dict[StridedGmemAccess( + key.dtype, key.stride, 'store') + ] = subs_assignee.dict.pop(key) insn_inames = knl.insn_inames(insn) # use count excluding local index tags for uniform accesses for key in subs_expr.dict: poly = ToCountMap({key: subs_expr.dict[key]}) - if key[1].stride == 0: + if isinstance(key.stride, int) and key.stride == 0: subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, True) else: subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames) for key in subs_assignee.dict: poly = ToCountMap({key: subs_assignee.dict[key]}) - if key[1].stride == 0: + if isinstance(key.stride, int) and key.stride == 0: subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, True) else: subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames) diff --git a/test/test_statistics.py b/test/test_statistics.py index 6e5b6270be571ae7c66e2219bcafdad0d4b63efd..a4fc022d5697ccf00835d69b9822a66ff5bfc456 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -32,7 +32,7 @@ from loopy.statistics import ( get_op_poly, get_gmem_access_poly, get_barrier_poly, - DataAccess) + StridedGmemAccess) import numpy as np @@ -232,21 +232,17 @@ def test_gmem_access_counter_basic(): m = 256 l = 128 params = {'n': n, 'm': m, 'l': l} - f32 = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'load') - ].eval_with_dict(params) - f64 = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'load') - ].eval_with_dict(params) + f32 = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'load') + ].eval_with_dict(params) + f64 = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'load') + ].eval_with_dict(params) assert f32 == 3*n*m*l assert f64 == 2*n*m - f32 = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'store') - ].eval_with_dict(params) - f64 = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'store') - ].eval_with_dict(params) + f32 = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'store') + ].eval_with_dict(params) + f64 = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'store') + ].eval_with_dict(params) assert f32 == n*m*l assert f64 == n*m @@ -266,14 +262,12 @@ def test_gmem_access_counter_reduction(): m = 256 l = 128 params = {'n': n, 'm': m, 'l': l} - f32 = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'load') - ].eval_with_dict(params) + f32 = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'load') + ].eval_with_dict(params) assert f32 == 2*n*m*l - f32 = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'store') - ].eval_with_dict(params) + f32 = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'store') + ].eval_with_dict(params) assert f32 == n*l @@ -294,18 +288,15 @@ def test_gmem_access_counter_logic(): m = 256 l = 128 params = {'n': n, 'm': m, 'l': l} - f32 = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'load') - ].eval_with_dict(params) - f64 = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'load') - ].eval_with_dict(params) + f32 = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'load') + ].eval_with_dict(params) + f64 = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'load') + ].eval_with_dict(params) assert f32 == 2*n*m assert f64 == n*m - f64 = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'store') - ].eval_with_dict(params) + f64 = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'store') + ].eval_with_dict(params) assert f64 == n*m @@ -328,21 +319,17 @@ def test_gmem_access_counter_specialops(): m = 256 l = 128 params = {'n': n, 'm': m, 'l': l} - f32 = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'load') - ].eval_with_dict(params) - f64 = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'load') - ].eval_with_dict(params) + f32 = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'load') + ].eval_with_dict(params) + f64 = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'load') + ].eval_with_dict(params) assert f32 == 2*n*m*l assert f64 == 2*n*m - f32 = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'store') - ].eval_with_dict(params) - f64 = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'store') - ].eval_with_dict(params) + f32 = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'store') + ].eval_with_dict(params) + f64 = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'store') + ].eval_with_dict(params) assert f32 == n*m*l assert f64 == n*m @@ -369,14 +356,12 @@ def test_gmem_access_counter_bitwise(): m = 256 l = 128 params = {'n': n, 'm': m, 'l': l} - i32 = poly[ - (np.dtype(np.int32), DataAccess(stride=0), 'load') - ].eval_with_dict(params) + i32 = poly[StridedGmemAccess(np.dtype(np.int32), 0, 'load') + ].eval_with_dict(params) assert i32 == 4*n*m+2*n*m*l - i32 = poly[ - (np.dtype(np.int32), DataAccess(stride=0), 'store') - ].eval_with_dict(params) + i32 = poly[StridedGmemAccess(np.dtype(np.int32), 0, 'store') + ].eval_with_dict(params) assert i32 == n*m+n*m*l @@ -403,24 +388,21 @@ def test_gmem_access_counter_mixed(): m = 256 l = 128 params = {'n': n, 'm': m, 'l': l} - f64uniform = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'load') - ].eval_with_dict(params) - f32uniform = poly[ - (np.dtype(np.float32), DataAccess(stride=0), 'load') - ].eval_with_dict(params) + f64uniform = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'load') + ].eval_with_dict(params) + f32uniform = poly[StridedGmemAccess(np.dtype(np.float32), 0, 'load') + ].eval_with_dict(params) f32nonconsec = poly[ - (np.dtype(np.float32), DataAccess(stride=Variable('m')), 'load') + StridedGmemAccess(np.dtype(np.float32), Variable('m'), 'load') ].eval_with_dict(params) assert f64uniform == 2*n*m assert f32uniform == n*m*l/threads assert f32nonconsec == 3*n*m*l - f64uniform = poly[ - (np.dtype(np.float64), DataAccess(stride=0), 'store') - ].eval_with_dict(params) + f64uniform = poly[StridedGmemAccess(np.dtype(np.float64), 0, 'store') + ].eval_with_dict(params) f32nonconsec = poly[ - (np.dtype(np.float32), DataAccess(stride=Variable('m')), 'store') + StridedGmemAccess(np.dtype(np.float32), Variable('m'), 'store') ].eval_with_dict(params) assert f64uniform == n*m assert f32nonconsec == n*m*l @@ -447,21 +429,21 @@ def test_gmem_access_counter_nonconsec(): m = 256 l = 128 params = {'n': n, 'm': m, 'l': l} - f64nonconsec = poly[ - (np.dtype(np.float64), DataAccess(stride=Variable('m')), 'load') - ].eval_with_dict(params) - f32nonconsec = poly[ - (np.dtype(np.float32), DataAccess(stride=Variable('m')*Variable('l')), 'load') - ].eval_with_dict(params) + f64nonconsec = poly[StridedGmemAccess( + np.dtype(np.float64), Variable('m'), 'load') + ].eval_with_dict(params) + f32nonconsec = poly[StridedGmemAccess( + np.dtype(np.float32), Variable('m')*Variable('l'), 'load') + ].eval_with_dict(params) assert f64nonconsec == 2*n*m assert f32nonconsec == 3*n*m*l - f64nonconsec = poly[ - (np.dtype(np.float64), DataAccess(stride=Variable('m')), 'store') - ].eval_with_dict(params) - f32nonconsec = poly[ - (np.dtype(np.float32), DataAccess(stride=Variable('m')*Variable('l')), 'store') - ].eval_with_dict(params) + f64nonconsec = poly[StridedGmemAccess( + np.dtype(np.float64), Variable('m'), 'store') + ].eval_with_dict(params) + f32nonconsec = poly[StridedGmemAccess( + np.dtype(np.float32), Variable('m')*Variable('l'), 'store') + ].eval_with_dict(params) assert f64nonconsec == n*m assert f32nonconsec == n*m*l @@ -487,21 +469,17 @@ def test_gmem_access_counter_consec(): l = 128 params = {'n': n, 'm': m, 'l': l} - f64consec = poly[ - (np.dtype(np.float64), DataAccess(stride=1), 'load') - ].eval_with_dict(params) - f32consec = poly[ - (np.dtype(np.float32), DataAccess(stride=1), 'load') - ].eval_with_dict(params) + f64consec = poly[StridedGmemAccess(np.dtype(np.float64), 1, 'load') + ].eval_with_dict(params) + f32consec = poly[StridedGmemAccess(np.dtype(np.float32), 1, 'load') + ].eval_with_dict(params) assert f64consec == 2*n*m assert f32consec == 3*n*m*l - f64consec = poly[ - (np.dtype(np.float64), DataAccess(stride=1), 'store') - ].eval_with_dict(params) - f32consec = poly[ - (np.dtype(np.float32), DataAccess(stride=1), 'store') - ].eval_with_dict(params) + f64consec = poly[StridedGmemAccess(np.dtype(np.float64), 1, 'store') + ].eval_with_dict(params) + f32consec = poly[StridedGmemAccess(np.dtype(np.float32), 1, 'store') + ].eval_with_dict(params) assert f64consec == n*m assert f32consec == n*m*l @@ -593,19 +571,17 @@ def test_all_counters_parallel_matmul(): assert i32ops == n*m*l*4 + l*n*4 subscript_map = get_gmem_access_poly(knl) - f32uncoal = subscript_map[ - (np.dtype(np.float32), DataAccess(stride=Variable('m')), 'load') - ].eval_with_dict(params) - f32coal = subscript_map[ - (np.dtype(np.float32), DataAccess(stride=1), 'load') - ].eval_with_dict(params) + f32uncoal = subscript_map[StridedGmemAccess( + np.dtype(np.float32), Variable('m'), 'load') + ].eval_with_dict(params) + f32coal = subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, 'load') + ].eval_with_dict(params) assert f32uncoal == n*m*l assert f32coal == n*m*l - f32coal = subscript_map[ - (np.dtype(np.float32), DataAccess(stride=1), 'store') - ].eval_with_dict(params) + f32coal = subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, 'store') + ].eval_with_dict(params) assert f32coal == n*l