diff --git a/loopy/statistics.py b/loopy/statistics.py index 24ce751ff5bd79308ac54f7c00875293eed0aa4e..efe2f1d0eabb1133dc29d2a2b7592b95f3a260be 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -135,16 +135,21 @@ class ExpressionOpCounter(CombineMapper): + self.rec(expr.exponent) def map_left_shift(self, expr): # implemented in CombineMapper - return self.rec(expr.shiftee)+self.rec(expr.shift) + return TypeToOpCountMap({self.type_inf(expr.shiftee): 1}) \ + + self.rec(expr.shiftee) \ + + self.rec(expr.shift) map_right_shift = map_left_shift def map_bitwise_not(self, expr): # implemented in CombineMapper - return self.rec(expr.child) + return TypeToOpCountMap({self.type_inf(expr.child): 1}) \ + + self.rec(expr.child) def map_bitwise_or(self, expr): # implemented in CombineMapper, maps to map_sum; - return sum(self.rec(child) for child in expr.children) + return TypeToOpCountMap( + {self.type_inf(expr): len(expr.children)-1} + ) + sum(self.rec(child) for child in expr.children) map_bitwise_xor = map_bitwise_or # implemented in CombineMapper, maps to map_sum; @@ -155,16 +160,22 @@ class ExpressionOpCounter(CombineMapper): def map_comparison(self, expr): # implemented in CombineMapper return self.rec(expr.left)+self.rec(expr.right) - map_logical_not = map_bitwise_not - map_logical_or = map_bitwise_or # implemented in CombineMapper, maps to map_sum + def map_logical_not(self, expr): + return self.rec(expr.child) + + def map_logical_or(self, expr): + return sum(self.rec(child) for child in expr.children) + map_logical_and = map_logical_or def map_if(self, expr): # implemented in CombineMapper, recurses - warnings.warn("Counting operations as sum of if-statement branches.") + warnings.warn("OpCounter counting DRAM accesses as " + "sum of if-statement branches.") return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_) def map_if_positive(self, expr): # implemented in FlopCounter - warnings.warn("Counting operations as sum of if_pos-statement branches.") + warnings.warn("OpCounter counting DRAM accesses as " + "sum of if_pos-statement branches.") return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_) map_min = map_bitwise_or @@ -194,16 +205,25 @@ class ExpressionOpCounter(CombineMapper): class SubscriptCounter(CombineMapper): - def __init__(self, kernel): - self.kernel = kernel + def __init__(self, knl): + self.knl = knl + from loopy.expression import TypeInferenceMapper + self.type_inf = TypeInferenceMapper(knl) def combine(self, values): return sum(values) + def map_constant(self, expr): + return TypeToOpCountMap() + + map_tagged_variable = map_constant + map_variable = map_constant + map_call = map_constant + def map_subscript(self, expr): name = expr.aggregate.name - arg = self.kernel.arg_dict.get(name) - tv = self.kernel.temporary_variables.get(name) + arg = self.knl.arg_dict.get(name) + tv = self.knl.temporary_variables.get(name) if arg is not None: if isinstance(arg, lp.GlobalArg): # It's global memory @@ -212,16 +232,12 @@ class SubscriptCounter(CombineMapper): if tv.is_local: # It's shared memory pass - return 1 + self.rec(expr.index) + #return 1 + self.rec(expr.index) + return TypeToOpCountMap( + {self.type_inf(expr): 1} + ) + self.rec(expr.index) - def map_constant(self, expr): - return 0 - - def map_variable(self, expr): - return 0 - -''' -class AccessCounter(CombineMapper): + ''' def map_subscript(self, expr): name = expr.aggregate.name if name in self.kernel.arg_dict: @@ -253,17 +269,89 @@ class AccessCounter(CombineMapper): # determine if stride 1 # find coefficient -''' + ''' + + #TODO find stride looking in ArrayBase.dim tag + ''' + for each instruction, find which iname is associated with local id0 (iname_to_tag) + then for each array axis in that instruction, run through all axes and see if local id0 iname occurs + for each axis where this occurs, see if stride=1 (using coefficient collecter) + + variable has dimTags (one for each axis), + localid 0 is threadidx.x + ''' + + def map_sum(self, expr): + if expr.children: + return sum(self.rec(child) for child in expr.children) + else: + return TypeToOpCountMap() + + map_product = map_sum + + def map_quotient(self, expr, *args): + return self.rec(expr.numerator) + self.rec(expr.denominator) + + map_floor_div = map_quotient + map_remainder = map_quotient + + def map_power(self, expr): + return self.rec(expr.base) + self.rec(expr.exponent) -#TODO find stride looking in ArrayBase.dim tag -''' -for each instruction, find which iname is associated with local id0 (iname_to_tag) -then for each array axis in that instruction, run through all axes and see if local id0 iname occurs -for each axis where this occurs, see if stride=1 (using coefficient collecter) + def map_left_shift(self, expr): + return self.rec(expr.shiftee)+self.rec(expr.shift) + + map_right_shift = map_left_shift -variable has dimTags (one for each axis), -localid 0 is threadidx.x -''' + def map_bitwise_not(self, expr): + return self.rec(expr.child) + + def map_bitwise_or(self, expr): + return sum(self.rec(child) for child in expr.children) + + map_bitwise_xor = map_bitwise_or + map_bitwise_and = map_bitwise_or + + def map_comparison(self, expr): + return self.rec(expr.left)+self.rec(expr.right) + + map_logical_not = map_bitwise_not + map_logical_or = map_bitwise_or + map_logical_and = map_logical_or + + def map_if(self, expr): + warnings.warn("SubscriptCounter counting DRAM accesses as " + "sum of if-statement branches.") + return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_) + + def map_if_positive(self, expr): + warnings.warn("SubscriptCounter counting DRAM accesses as " + "sum of if_pos-statement branches.") + return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_) + + map_min = map_bitwise_or + map_max = map_min + + def map_common_subexpression(self, expr): + raise NotImplementedError("SubscriptCounter encountered " + "common_subexpression, " + "map_common_subexpression not implemented.") + return 0 + + def map_substitution(self, expr): + raise NotImplementedError("SubscriptCounter encountered substitution, " + "map_substitution not implemented.") + return 0 + + def map_derivative(self, expr): + raise NotImplementedError("SubscriptCounter encountered derivative, " + "map_derivative not implemented.") + return 0 + + def map_slice(self, expr): + raise NotImplementedError("SubscriptCounter encountered slice, " + "map_slice not implemented.") + return 0 def count(kernel, bset): @@ -316,12 +404,17 @@ def get_op_poly(knl): def get_DRAM_access_poly(knl): # for now just counting subscripts - raise NotImplementedError("get_DRAM_access_poly not yet implemented.") - poly = 0 + # raise NotImplementedError("get_DRAM_access_poly not yet implemented.") + from loopy.preprocess import preprocess_kernel, infer_unknown_types + knl = infer_unknown_types(knl, expect_completion=True) + knl = preprocess_kernel(knl) + + subs_poly = 0 subscript_counter = SubscriptCounter(knl) for insn in knl.instructions: insn_inames = knl.insn_inames(insn) inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) - poly += subscript_counter(insn.expression) * count(knl, domain) - return poly + subs = subscript_counter(insn.expression) + subs_poly = subs_poly + subs*count(knl, domain) + return subs_poly diff --git a/test/test_statistics.py b/test/test_statistics.py index bfcd199f17dd4fa9c176145ebc4f84d1a848b987..29cb5f98ed273a178a0dd4cc927b952ec41f8d76 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -131,21 +131,27 @@ def test_op_counter_bitwise(): [ """ c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1) - e[i, k] = (g[i,k] ^ k)*(~h[i,k+1]) + (g[i, k] << (h[i,k] >> k)) + e[i, k] = (g[i,k] ^ k)*~(h[i,k+1]) + (g[i, k] << (h[i,k] >> k)) """ ], name="bitwise", assumptions="n,m,l >= 1") - knl = lp.add_and_infer_dtypes(knl, - dict(a=np.int32, b=np.int32, g=np.int64, h=np.int64)) + knl = lp.add_and_infer_dtypes( + knl, dict( + a=np.int32, b=np.int32, + g=np.int64, h=np.int64)) poly = get_op_poly(knl) - n = 512 - m = 256 - l = 128 - i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l}) - print(poly.dict[np.dtype(np.int32)]) - not_there = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l}) - assert i32 == 3*n*m+n*m*l + + n = 10 + m = 10 + l = 10 + param_values = {'n': n, 'm': m, 'l': l} + i32 = poly.dict[np.dtype(np.int32)].eval_with_dict(param_values) + i64 = poly.dict[np.dtype(np.int64)].eval_with_dict(param_values) + not_there = poly[np.dtype(np.float64)].eval_with_dict(param_values) + print(poly.dict) + assert i32 == n*m + n*m*l + assert i64 == 2*n*m assert not_there == 0