diff --git a/loopy/statistics.py b/loopy/statistics.py index d226572487e0c483bde927b3d48a3aab46f65322..15b0605ecece5e77751827519a1aca4a13a9b2b0 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -31,7 +31,7 @@ import islpy as isl from pymbolic.mapper import CombineMapper -class TypeToOpCountMap: +class TypeToCountMap: def __init__(self, init_dict=None): if init_dict is None: @@ -45,24 +45,24 @@ class TypeToOpCountMap: for k, v in six.iteritems(other.dict): result[k] = self.dict.get(k, 0) + v - return TypeToOpCountMap(result) + return TypeToCountMap(result) def __radd__(self, other): if other != 0: - raise ValueError("TypeToOpCountMap: Attempted to add TypeToOpCountMap " - "to {} {}. TypeToOpCountMap may only be added to " - "0 and other TypeToOpCountMap objects." + raise ValueError("TypeToCountMap: Attempted to add TypeToCountMap " + "to {} {}. TypeToCountMap may only be added to " + "0 and other TypeToCountMap objects." .format(type(other), other)) return return self def __mul__(self, other): if isinstance(other, isl.PwQPolynomial): - return TypeToOpCountMap({index: self.dict[index]*other + return TypeToCountMap({index: self.dict[index]*other for index in self.dict.keys()}) else: - raise ValueError("TypeToOpCountMap: Attempted to multiply " - "TypeToOpCountMap by {} {}." + raise ValueError("TypeToCountMap: Attempted to multiply " + "TypeToCountMap by {} {}." .format(type(other), other)) __rmul__ = __mul__ @@ -88,7 +88,7 @@ class ExpressionOpCounter(CombineMapper): return sum(values) def map_constant(self, expr): - return TypeToOpCountMap() + return TypeToCountMap() map_tagged_variable = map_constant map_variable = map_constant @@ -110,16 +110,16 @@ class ExpressionOpCounter(CombineMapper): def map_sum(self, expr): if expr.children: - return TypeToOpCountMap( + return TypeToCountMap( {self.type_inf(expr): len(expr.children)-1} ) + sum(self.rec(child) for child in expr.children) else: - return TypeToOpCountMap() + return TypeToCountMap() map_product = map_sum def map_quotient(self, expr, *args): - return TypeToOpCountMap({self.type_inf(expr): 1}) \ + return TypeToCountMap({self.type_inf(expr): 1}) \ + self.rec(expr.numerator) \ + self.rec(expr.denominator) @@ -127,21 +127,26 @@ class ExpressionOpCounter(CombineMapper): map_remainder = map_quotient # implemented in CombineMapper def map_power(self, expr): - return TypeToOpCountMap({self.type_inf(expr): 1}) \ + return TypeToCountMap({self.type_inf(expr): 1}) \ + self.rec(expr.base) \ + self.rec(expr.exponent) def map_left_shift(self, expr): # implemented in CombineMapper - return self.rec(expr.shiftee)+self.rec(expr.shift) + return TypeToCountMap({self.type_inf(expr): 1}) \ + + self.rec(expr.shiftee) \ + + self.rec(expr.shift) map_right_shift = map_left_shift def map_bitwise_not(self, expr): # implemented in CombineMapper - return self.rec(expr.child) + return TypeToCountMap({self.type_inf(expr): 1}) \ + + self.rec(expr.child) def map_bitwise_or(self, expr): # implemented in CombineMapper, maps to map_sum; - return sum(self.rec(child) for child in expr.children) + return TypeToCountMap( + {self.type_inf(expr): len(expr.children)-1} + ) + sum(self.rec(child) for child in expr.children) map_bitwise_xor = map_bitwise_or # implemented in CombineMapper, maps to map_sum; @@ -152,16 +157,22 @@ class ExpressionOpCounter(CombineMapper): def map_comparison(self, expr): # implemented in CombineMapper return self.rec(expr.left)+self.rec(expr.right) - map_logical_not = map_bitwise_not - map_logical_or = map_bitwise_or # implemented in CombineMapper, maps to map_sum + def map_logical_not(self, expr): + return self.rec(expr.child) + + def map_logical_or(self, expr): + return sum(self.rec(child) for child in expr.children) + map_logical_and = map_logical_or def map_if(self, expr): # implemented in CombineMapper, recurses - warnings.warn("Counting operations as sum of if-statement branches.") + warnings.warn("ExpressionOpCounter counting DRAM accesses as " + "sum of if-statement branches.") return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_) def map_if_positive(self, expr): # implemented in FlopCounter - warnings.warn("Counting operations as sum of if_pos-statement branches.") + warnings.warn("ExpressionOpCounter counting DRAM accesses as " + "sum of if_pos-statement branches.") return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_) map_min = map_bitwise_or @@ -170,51 +181,201 @@ class ExpressionOpCounter(CombineMapper): map_max = map_min # implemented in CombineMapper, maps to map_sum; # TODO test def map_common_subexpression(self, expr): - raise NotImplementedError("OpCounter encountered common_subexpression, " + raise NotImplementedError("ExpressionOpCounter encountered " + "common_subexpression, " "map_common_subexpression not implemented.") return 0 def map_substitution(self, expr): - raise NotImplementedError("OpCounter encountered substitution, " + raise NotImplementedError("ExpressionOpCounter encountered substitution, " "map_substitution not implemented.") return 0 def map_derivative(self, expr): - raise NotImplementedError("OpCounter encountered derivative, " + raise NotImplementedError("ExpressionOpCounter encountered derivative, " "map_derivative not implemented.") return 0 def map_slice(self, expr): - raise NotImplementedError("OpCounter encountered slice, " + raise NotImplementedError("ExpressionOpCounter encountered slice, " "map_slice not implemented.") return 0 -class SubscriptCounter(CombineMapper): - def __init__(self, kernel): - self.kernel = kernel +class ExpressionSubscriptCounter(CombineMapper): + + def __init__(self, knl): + self.knl = knl + from loopy.expression import TypeInferenceMapper + self.type_inf = TypeInferenceMapper(knl) def combine(self, values): return sum(values) + def map_constant(self, expr): + return TypeToCountMap() + + map_tagged_variable = map_constant + map_variable = map_constant + map_call = map_constant + def map_subscript(self, expr): - name = expr.aggregate.name - arg = self.kernel.arg_dict.get(name) - tv = self.kernel.temporary_variables.get(name) - if arg is not None: - if isinstance(arg, lp.GlobalArg): - # It's global memory - pass - elif tv is not None: - if tv.is_local: - # It's shared memory - pass - return 1 + self.rec(expr.index) + name = expr.aggregate.name # name of array - def map_constant(self, expr): + if name in self.knl.arg_dict: + array = self.knl.arg_dict[name] + else: + # this is a temporary variable + return self.rec(expr.index) + + if not isinstance(array, lp.GlobalArg): + # this array is not in global memory + return self.rec(expr.index) + + index = expr.index # could be tuple or scalar index + if not isinstance(index, tuple): + index = (index,) + + from loopy.symbolic import get_dependencies + from loopy.kernel.data import LocalIndexTag + my_inames = get_dependencies(index) & self.knl.all_inames() + local_id0 = None + local_id_found = False + for iname in my_inames: + # find local id0 + tag = self.knl.iname_to_tag.get(iname) + if isinstance(tag, LocalIndexTag): + local_id_found = True + if tag.axis == 0: + local_id0 = iname + break # there will be only one local_id0 + + if not local_id_found: + # count as uniform access + warnings.warn("ExpressionSubscriptCounter did not find " + "local iname tags in expression:\n %s,\n" + "considering these DRAM accesses uniform." % expr) + return TypeToCountMap( + {(self.type_inf(expr), 'uniform'): 1} + ) + self.rec(expr.index) + + if local_id0 is None: + # only non-zero local id(s) found, assume non-consecutive access + return TypeToCountMap( + {(self.type_inf(expr), 'nonconsecutive'): 1} + ) + self.rec(expr.index) + + # check coefficient of local_id0 for each axis + from loopy.symbolic import CoefficientCollector + from pymbolic.primitives import Variable + for idx, axis_tag in zip(index, array.dim_tags): + + coeffs = CoefficientCollector()(idx) + # check if he contains the lid 0 guy + try: + coeff_id0 = coeffs[Variable(local_id0)] + except KeyError: + # does not contain local_id0 + continue + + if coeff_id0 != 1: + # non-consecutive access + return TypeToCountMap( + {(self.type_inf(expr), 'nonconsecutive'): 1} + ) + self.rec(expr.index) + + # coefficient is 1, now determine if stride is 1 + from loopy.kernel.array import FixedStrideArrayDimTag + if isinstance(axis_tag, FixedStrideArrayDimTag): + stride = axis_tag.stride + else: + continue + + if stride != 1: + # non-consecutive + return TypeToCountMap( + {(self.type_inf(expr), 'nonconsecutive'): 1} + ) + self.rec(expr.index) + + # else, stride == 1, continue since another idx could contain id0 + + # loop finished without returning, stride==1 for every instance of local_id0 + return TypeToCountMap( + {(self.type_inf(expr), 'consecutive'): 1} + ) + self.rec(expr.index) + + def map_sum(self, expr): + if expr.children: + return sum(self.rec(child) for child in expr.children) + else: + return TypeToCountMap() + + map_product = map_sum + + def map_quotient(self, expr, *args): + return self.rec(expr.numerator) + self.rec(expr.denominator) + + map_floor_div = map_quotient + map_remainder = map_quotient + + def map_power(self, expr): + return self.rec(expr.base) + self.rec(expr.exponent) + + def map_left_shift(self, expr): + return self.rec(expr.shiftee)+self.rec(expr.shift) + + map_right_shift = map_left_shift + + def map_bitwise_not(self, expr): + return self.rec(expr.child) + + def map_bitwise_or(self, expr): + return sum(self.rec(child) for child in expr.children) + + map_bitwise_xor = map_bitwise_or + map_bitwise_and = map_bitwise_or + + def map_comparison(self, expr): + return self.rec(expr.left)+self.rec(expr.right) + + map_logical_not = map_bitwise_not + map_logical_or = map_bitwise_or + map_logical_and = map_logical_or + + def map_if(self, expr): + warnings.warn("ExpressionSubscriptCounter counting DRAM accesses as " + "sum of if-statement branches.") + return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_) + + def map_if_positive(self, expr): + warnings.warn("ExpressionSubscriptCounter counting DRAM accesses as " + "sum of if_pos-statement branches.") + return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_) + + map_min = map_bitwise_or + map_max = map_min + + def map_common_subexpression(self, expr): + raise NotImplementedError("ExpressionSubscriptCounter encountered " + "common_subexpression, " + "map_common_subexpression not implemented.") return 0 - def map_variable(self, expr): + def map_substitution(self, expr): + raise NotImplementedError("ExpressionSubscriptCounter encountered " + "substitution, " + "map_substitution not implemented.") + return 0 + + def map_derivative(self, expr): + raise NotImplementedError("ExpressionSubscriptCounter encountered " + "derivative, " + "map_derivative not implemented.") + return 0 + + def map_slice(self, expr): + raise NotImplementedError("ExpressionSubscriptCounter encountered slice, " + "map_slice not implemented.") return 0 @@ -268,12 +429,47 @@ def get_op_poly(knl): def get_DRAM_access_poly(knl): # for now just counting subscripts - raise NotImplementedError("get_DRAM_access_poly not yet implemented.") - poly = 0 - subscript_counter = SubscriptCounter(knl) + from loopy.preprocess import preprocess_kernel, infer_unknown_types + knl = infer_unknown_types(knl, expect_completion=True) + knl = preprocess_kernel(knl) + + subs_poly = 0 + subscript_counter = ExpressionSubscriptCounter(knl) for insn in knl.instructions: insn_inames = knl.insn_inames(insn) inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) - poly += subscript_counter(insn.expression) * count(knl, domain) - return poly + subs = subscript_counter(insn.expression) + subs_poly = subs_poly + subs*count(knl, domain) + return subs_poly + + +def get_barrier_poly(knl): + from loopy.preprocess import preprocess_kernel, infer_unknown_types + from loopy.schedule import EnterLoop, LeaveLoop, Barrier + from operator import mul + knl = infer_unknown_types(knl, expect_completion=True) + knl = preprocess_kernel(knl) + knl = lp.get_one_scheduled_kernel(knl) + iname_list = [] + barrier_poly = isl.PwQPolynomial('{ 0 }') # 0 + + for sched_item in knl.schedule: + if isinstance(sched_item, EnterLoop): + if sched_item.iname: # (if not empty) + iname_list.append(sched_item.iname) + elif isinstance(sched_item, LeaveLoop): + if sched_item.iname: # (if not empty) + iname_list.pop() + elif isinstance(sched_item, Barrier): + if iname_list: # (if iname_list is not empty) + ct = (count(knl, ( + knl.get_inames_domain(iname_list). + project_out_except(iname_list, [dim_type.set]) + )), ) + barrier_poly += reduce(mul, ct) + else: + barrier_poly += isl.PwQPolynomial('{ 1 }') + + return barrier_poly + diff --git a/test/test_statistics.py b/test/test_statistics.py index 6867ca28f696b4e9de827ecb340698ada6e8cfa9..dc040864f4a0affe4f0356008d1f5ea46450f471 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -27,7 +27,7 @@ from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) import loopy as lp -from loopy.statistics import get_op_poly # noqa +from loopy.statistics import get_op_poly, get_DRAM_access_poly, get_barrier_poly import numpy as np @@ -41,7 +41,7 @@ def test_op_counter_basic(): e[i, k] = g[i,k]*h[i,k+1] """ ], - name="weird", assumptions="n,m,l >= 1") + name="basic", assumptions="n,m,l >= 1") knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) @@ -64,7 +64,7 @@ def test_op_counter_reduction(): [ "c[i, j] = sum(k, a[i, k]*b[k, j])" ], - name="matmul", assumptions="n,m,l >= 1") + name="matmul_serial", assumptions="n,m,l >= 1") knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) poly = get_op_poly(knl) @@ -143,18 +143,15 @@ def test_op_counter_bitwise(): g=np.int64, h=np.int64)) poly = get_op_poly(knl) - - n = 10 - m = 10 - l = 10 - param_values = {'n': n, 'm': m, 'l': l} - i32 = poly.dict[np.dtype(np.int32)].eval_with_dict(param_values) - i64 = poly.dict[np.dtype(np.int64)].eval_with_dict(param_values) - not_there = poly[np.dtype(np.float64)].eval_with_dict(param_values) - print(poly.dict) - assert i32 == n*m + n*m*l - assert i64 == 2*n*m - assert not_there == 0 + n = 512 + m = 256 + l = 128 + i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l}) + i64 = poly.dict[np.dtype(np.int64)].eval_with_dict({'n': n, 'm': m, 'l': l}) # noqa + f64 = poly[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l}) + assert i32 == n*m+3*n*m*l + assert i64 == 6*n*m + assert f64 == 0 def test_op_counter_triangular_domain(): @@ -188,9 +185,317 @@ def test_op_counter_triangular_domain(): assert flops == 78 +def test_DRAM_access_counter_basic(): + + knl = lp.make_kernel( + "[n,m,l] -> {[i,k,j]: 0<=i6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2) + """ + ], + name="logic", assumptions="n,m,l >= 1") + + knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64)) + poly = get_DRAM_access_poly(knl) + n = 512 + m = 256 + l = 128 + f32 = poly.dict[ + (np.dtype(np.float32), 'uniform') + ].eval_with_dict({'n': n, 'm': m, 'l': l}) + f64 = poly.dict[ + (np.dtype(np.float64), 'uniform') + ].eval_with_dict({'n': n, 'm': m, 'l': l}) + assert f32 == 2*n*m + assert f64 == n*m + + +def test_DRAM_access_counter_specialops(): + + knl = lp.make_kernel( + "{[i,k,j]: 0<=i> k)) + """ + ], + name="bitwise", assumptions="n,m,l >= 1") + + knl = lp.add_and_infer_dtypes( + knl, dict( + a=np.int32, b=np.int32, + g=np.int32, h=np.int32)) + + poly = get_DRAM_access_poly(knl) + n = 512 + m = 256 + l = 128 + i32 = poly.dict[ + (np.dtype(np.int32), 'uniform') + ].eval_with_dict({'n': n, 'm': m, 'l': l}) + assert i32 == 4*n*m+2*n*m*l + + +def test_DRAM_access_counter_mixed(): + + knl = lp.make_kernel( + "[n,m,l] -> {[i,k,j]: 0<=i {[i,k,j]: 0<=i {[i,k,j]: 0<=i {[i,k,j]: 0<=i {[i,k,j]: 0<=i<50 and 1<=k<98 and 0<=j<10}", + [ + """ + c[i,j,k] = 2*a[i,j,k] {id=first} + e[i,j,k] = c[i,j,k+1]+c[i,j,k-1] {dep=first} + """ + ], [ + lp.TemporaryVariable("c", lp.auto, shape=(50, 10, 99)), + "..." + ], + name="weird2", + ) + knl = lp.add_and_infer_dtypes(knl, dict(a=np.int32)) + knl = lp.split_iname(knl, "k", 128, outer_tag="g.0", inner_tag="l.0") + poly = get_barrier_poly(knl) + n = 512 + m = 256 + l = 128 + barrier_count = poly.eval_with_dict({'n': n, 'm': m, 'l': l}) + assert barrier_count == 50*10*2 + + +def test_all_counters_parallel_matmul(): + + knl = lp.make_kernel( + "{[i,k,j]: 0<=i 1: exec(sys.argv[1]) else: from py.test.cmdline import main main([__file__]) +