class StridedGmemAccess:
    """Hashable key describing one kind of strided global-memory access.

    Instances are used as dictionary keys when counting memory accesses.

    :arg dtype: data type of the accessed value; compared with ``==`` and
        hashed via ``str(dtype)``.
    :arg stride: stride of the access (``0`` is used by callers for uniform
        accesses); compared with ``==`` and hashed via ``str(stride)``.
    :arg direction: *None*, or a tag such as ``'load'`` / ``'store'``.
    :arg variable: name of the accessed variable, or :attr:`ANY_VAR`
        (the default) to match accesses to any variable.

    .. note::

        Because :attr:`ANY_VAR` acts as a wildcard, equality is not
        transitive: ``a == wildcard`` and ``wildcard == b`` do not imply
        ``a == b``.  Mixing wildcard and named keys in the *same* dictionary
        can therefore merge or shadow entries; looking up a wildcard key in
        a dictionary of named keys returns one arbitrary match.
    """

    # Wildcard value for *variable*: compares equal to every variable name.
    ANY_VAR = 'ANY_VAR'

    def __init__(self, dtype, stride, direction=None, variable=ANY_VAR):
        self.dtype = dtype
        self.stride = stride
        self.direction = direction
        self.variable = variable

    def __eq__(self, other):
        """Return whether *other* describes the same access.

        dtype, stride and direction must match exactly; the variables match
        if they are equal or if either side is :attr:`ANY_VAR`.
        """
        if not isinstance(other, StridedGmemAccess):
            return False

        variables_match = (
                self.variable == self.ANY_VAR
                or other.variable == self.ANY_VAR
                or self.variable == other.variable)

        return (
                other.dtype == self.dtype and
                other.stride == self.stride and
                other.direction == self.direction and
                variables_match)

    def __ne__(self, other):
        # Python 2 does not derive != from ==; define it explicitly so both
        # operators always agree.
        return not self.__eq__(other)

    def __hash__(self):
        # *variable* is deliberately left out: a wildcard access compares
        # equal to an access with a concrete variable, and objects that
        # compare equal must have equal hashes.  Accesses that differ only
        # in their variable land in the same bucket and are told apart by
        # __eq__.
        #
        # A tuple is hashed (rather than a concatenated string) so that
        # field boundaries are unambiguous and a non-string direction does
        # not raise TypeError.
        return hash((str(self.dtype), str(self.stride), self.direction))
with minimum tag axis total_stride = None - # check coefficient of local_id0 for each axis + extra_stride = 1 + # check coefficient of min_lid for each axis from loopy.symbolic import CoefficientCollector + from loopy.kernel.array import FixedStrideArrayDimTag from pymbolic.primitives import Variable for idx, axis_tag in zip(index, array.dim_tags): coeffs = CoefficientCollector()(idx) @@ -355,17 +365,22 @@ class GlobalSubscriptCounter(CombineMapper): continue # found coefficient of min_lid # now determine stride - from loopy.kernel.array import FixedStrideArrayDimTag if isinstance(axis_tag, FixedStrideArrayDimTag): stride = axis_tag.stride else: continue - total_stride = stride*coeff_min_lid + total_stride = stride*coeff_min_lid*extra_stride #TODO is there a case where this^ does not execute, or executes more than once for two different axes? + #TODO temporary fix that needs changing: + if min_tag_axis != 0: + print("...... min tag axis (%d) is not zero! ......" % (min_tag_axis)) + return ToCountMap({StridedGmemAccess(self.type_inf(expr), + sys.maxsize, direction=None, variable=name): 1}) + self.rec(expr.index) + return ToCountMap({StridedGmemAccess(self.type_inf(expr), - total_stride): 1}) + self.rec(expr.index) + total_stride, direction=None, variable=name): 1}) + self.rec(expr.index) def map_sum(self, expr): if expr.children: @@ -734,12 +749,12 @@ def get_gmem_access_poly(knl): # for now just counting subscripts subs_expr = subscript_counter(insn.expression) for key in subs_expr.dict: subs_expr.dict[StridedGmemAccess( - key.dtype, key.stride, 'load') + key.dtype, key.stride, direction='load', variable=key.variable) ] = subs_expr.dict.pop(key) subs_assignee = subscript_counter(insn.assignee) for key in subs_assignee.dict: subs_assignee.dict[StridedGmemAccess( - key.dtype, key.stride, 'store') + key.dtype, key.stride, direction='store', variable=key.variable) ] = subs_assignee.dict.pop(key) insn_inames = knl.insn_inames(insn)