From e16469d799a21f7cc5459a445c64bc05a21209cc Mon Sep 17 00:00:00 2001 From: James Stevens <jdsteve2@porter.cs.illinois.edu> Date: Fri, 1 Apr 2016 17:35:36 -0500 Subject: [PATCH] adding local mem access counter --- loopy/statistics.py | 188 ++++++++++++++++++++++++++++++++++++++++ test/test_statistics.py | 35 ++++---- 2 files changed, 204 insertions(+), 19 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 3f2c3a4b5..419eb2868 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -120,6 +120,24 @@ class TypedOp: return hash(str(self.dtype)+self.name) +class LmemAccess: + + def __init__(self, dtype, direction=None): + self.dtype = dtype + self.direction = direction + + def __eq__(self, other): + return isinstance(other, LmemAccess) and ( + other.dtype == self.dtype and + other.direction == self.direction) + + def __hash__(self): + direction = self.direction + if direction == None: + direction = 'None' + return hash(str(self.dtype)+direction) + + class StridedGmemAccess: #TODO "ANY_VAR" does not work yet @@ -279,6 +297,116 @@ class ExpressionOpCounter(CombineMapper): # }}} +# {{{ LocalSubscriptCounter + +class LocalSubscriptCounter(CombineMapper): + + def __init__(self, knl): + self.knl = knl + from loopy.expression import TypeInferenceMapper + self.type_inf = TypeInferenceMapper(knl) + + def combine(self, values): + return sum(values) + + def map_constant(self, expr): + return ToCountMap() + + map_tagged_variable = map_constant + map_variable = map_constant + + def map_call(self, expr): + return self.rec(expr.parameters) + + def map_subscript(self, expr): + name = expr.aggregate.name # name of array + + if name in self.knl.temporary_variables: + array = self.knl.temporary_variables[name] + #print("array: ", array) + #print("is local? ", array.is_local) + if array.is_local: + return ToCountMap( + {LmemAccess(self.type_inf(expr), direction=None): 1} + ) + self.rec(expr.index) + + return self.rec(expr.index) + + def map_sum(self, expr): + if expr.children: + return sum(self.rec(child) for child in expr.children) + else: + return ToCountMap() + + map_product = map_sum + + def map_quotient(self, expr, *args): + return self.rec(expr.numerator) + self.rec(expr.denominator) + + map_floor_div = map_quotient + map_remainder = map_quotient + + def map_power(self, expr): + return self.rec(expr.base) + self.rec(expr.exponent) + + def map_left_shift(self, expr): + return self.rec(expr.shiftee)+self.rec(expr.shift) + + map_right_shift = map_left_shift + + def map_bitwise_not(self, expr): + return self.rec(expr.child) + + def map_bitwise_or(self, expr): + return sum(self.rec(child) for child in expr.children) + + map_bitwise_xor = map_bitwise_or + map_bitwise_and = map_bitwise_or + + def map_comparison(self, expr): + return self.rec(expr.left)+self.rec(expr.right) + + map_logical_not = map_bitwise_not + map_logical_or = map_bitwise_or + map_logical_and = map_logical_or + + def map_if(self, expr): + warnings.warn("LocalSubscriptCounter counting LMEM accesses as " + "sum of if-statement branches.") + return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_) + + def map_if_positive(self, expr): + warnings.warn("LocalSubscriptCounter counting LMEM accesses as " + "sum of if_pos-statement branches.") + return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_) + + map_min = map_bitwise_or + map_max = map_min + + def map_common_subexpression(self, expr): + raise NotImplementedError("LocalSubscriptCounter encountered " + "common_subexpression, " + "map_common_subexpression not implemented.") + + def map_substitution(self, expr): + raise NotImplementedError("LocalSubscriptCounter encountered " + "substitution, " + "map_substitution not implemented.") + + def map_derivative(self, expr): + raise NotImplementedError("LocalSubscriptCounter encountered " + "derivative, " + "map_derivative not implemented.") + + def map_slice(self, expr): + raise NotImplementedError("LocalSubscriptCounter encountered slice, " + "map_slice not implemented.") + +# }}} + + + + # {{{ GlobalSubscriptCounter class GlobalSubscriptCounter(CombineMapper): @@ -674,6 +802,66 @@ def sum_ops_to_dtypes(op_poly_dict): return result +def get_lmem_access_poly(knl): + + """Count the number of local memory accesses in a loopy kernel. + """ + + from loopy.preprocess import preprocess_kernel, infer_unknown_types + + class CacheHolder(object): + pass + + cache_holder = CacheHolder() + + @memoize_in(cache_holder, "insn_count") + def get_insn_count(knl, insn_inames): + inames_domain = knl.get_inames_domain(insn_inames) + domain = (inames_domain.project_out_except( + insn_inames, [dim_type.set])) + return count(knl, domain) + + knl = infer_unknown_types(knl, expect_completion=True) + knl = preprocess_kernel(knl) + + subs_poly = ToCountMap() + subscript_counter = LocalSubscriptCounter(knl) + for insn in knl.instructions: + # count subscripts, distinguishing loads and stores + subs_expr = subscript_counter(insn.expression) + for key in subs_expr.dict: + subs_expr.dict[LmemAccess( + key.dtype, direction='load') + ] = subs_expr.dict.pop(key) + subs_assignee = subscript_counter(insn.assignee) + for key in subs_assignee.dict: + print(key.dtype, key.direction, subs_assignee.dict[key]) + + # for now, not counting stores in local mem + ''' + for key in subs_assignee.dict: + subs_assignee.dict[LmemAccess( + key.dtype, direction='store') + ] = subs_assignee.dict.pop(key) + ''' + + insn_inames = knl.insn_inames(insn) + + # use count excluding local index tags for uniform accesses + for key in subs_expr.dict: + poly = ToCountMap({key: subs_expr.dict[key]}) + subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames) + + # for now, not counting stores in local mem + ''' + for key in subs_assignee.dict: + poly = ToCountMap({key: subs_assignee.dict[key]}) + subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames) + ''' + + return subs_poly.dict + + # {{{ get_gmem_access_poly def get_gmem_access_poly(knl): # for now just counting subscripts diff --git a/test/test_statistics.py b/test/test_statistics.py index 56b9f4003..0353ac0d4 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -31,8 +31,10 @@ import loopy as lp from loopy.statistics import ( get_op_poly, get_gmem_access_poly, + get_lmem_access_poly, get_barrier_poly, StridedGmemAccess, + LmemAccess, TypedOp) import numpy as np @@ -578,6 +580,9 @@ def test_all_counters_parallel_matmul(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) knl = lp.split_iname(knl, "i", 16, outer_tag="g.0", inner_tag="l.1") knl = lp.split_iname(knl, "j", 16, outer_tag="g.1", inner_tag="l.0") + knl = lp.split_iname(knl, "k", 16) + knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"]) + knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"]) n = 512 m = 256 @@ -585,7 +590,7 @@ def test_all_counters_parallel_matmul(): params = {'n': n, 'm': m, 'l': l} barrier_count = get_barrier_poly(knl).eval_with_dict(params) - assert barrier_count == 0 + assert barrier_count == 2*m/16 op_map = get_op_poly(knl) f32mul = op_map[ @@ -602,35 +607,27 @@ def test_all_counters_parallel_matmul(): ].eval_with_dict(params) assert f32mul+f32add == n*m*l*2 - assert i32ops == n*m*l*4 + l*n*4 subscript_map = get_gmem_access_poly(knl) - #f32uncoal = subscript_map[StridedGmemAccess( - # np.dtype(np.float32), Variable('m'), direction='load', variable='ANY_VAR') - # ].eval_with_dict(params) - #test = StridedGmemAccess(np.dtype(np.float32), sys.maxsize, direction='load', variable='ANY_VAR') - #print("test key: ", test.dtype, test.stride, test.direction, test.variable) - #for key in subscript_map: - # print(key.dtype, key.stride, key.direction, key.variable) - f32uncoal = subscript_map[StridedGmemAccess( - np.dtype(np.float32), sys.maxsize, direction='load', variable='a') - ].eval_with_dict(params) - ''' - f32uncoal = subscript_map[StridedGmemAccess( - np.dtype(np.float32), sys.maxsize, direction='load', variable='ANY_VAR') - ].eval_with_dict(params) - ''' + f32coal = subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, direction='load', variable='b') ].eval_with_dict(params) + f32coal += subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, direction='load', variable='a') + ].eval_with_dict(params) - assert f32uncoal == n*m*l - assert f32coal == n*m*l + assert f32coal == n*m+m*l f32coal = subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, direction='store', variable='c') ].eval_with_dict(params) assert f32coal == n*l + local_subs_map = get_lmem_access_poly(knl) + + local_subs_l = local_subs_map[LmemAccess(np.dtype(np.float32), direction='load') + ].eval_with_dict(params) + + assert local_subs_l == n*m*l*2 def test_gather_access_footprint(): knl = lp.make_kernel( -- GitLab