diff --git a/loopy/statistics.py b/loopy/statistics.py index 157eb70d592d8f1fbf0c5dce1526e1196e1aacae..fde8643bf92b7ad56bb47975fa7ede1bda9b399c 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -25,8 +25,6 @@ THE SOFTWARE. import six import loopy as lp -import numpy as np -import warnings from islpy import dim_type import islpy as isl from pytools import memoize_in @@ -319,7 +317,6 @@ class ToCountMap(object): return result - def sum(self): """Add all counts in ToCountMap. @@ -335,7 +332,6 @@ class ToCountMap(object): total += v return total - def eval_and_sum(self, params): """Add all counts in :class:`ToCountMap` and evaluate with provided parameter dict. @@ -443,7 +439,8 @@ class MemAccess(object): """ - def __init__(self, mtype=None, dtype=None, stride=None, direction=None, variable=None): + def __init__(self, mtype=None, dtype=None, stride=None, direction=None, + variable=None): self.mtype = mtype self.stride = stride self.direction = direction @@ -501,8 +498,8 @@ class MemAccess(object): variable = 'None' else: variable = self.variable - return "MemAccess("+mtype+", "+dtype+", "+stride+", "+direction+", " \ - +variable+")" + return "MemAccess(" + mtype + ", " + dtype + ", " + stride + ", " \ + + direction + ", " + variable + ")" # {{{ ExpressionOpCounter @@ -574,8 +571,8 @@ class ExpressionOpCounter(CombineMapper): def map_bitwise_or(self, expr): return ToCountMap({Op(dtype=self.type_inf(expr), name='bw'): - len(expr.children)-1} - ) + sum(self.rec(child) for child in expr.children) + len(expr.children)-1}) \ + + sum(self.rec(child) for child in expr.children) map_bitwise_xor = map_bitwise_or map_bitwise_and = map_bitwise_or @@ -596,8 +593,8 @@ class ExpressionOpCounter(CombineMapper): def map_min(self, expr): return ToCountMap({Op(dtype=self.type_inf(expr), name='maxmin'): - len(expr.children)-1} - ) + sum(self.rec(child) for child in expr.children) + len(expr.children)-1}) \ + + sum(self.rec(child) for child in expr.children) map_max = map_min @@ -739,7 +736,7 @@ class GlobalSubscriptCounter(CombineMapper): index = (index,) from loopy.symbolic import get_dependencies - from loopy.kernel.data import LocalIndexTag, GroupIndexTag + from loopy.kernel.data import LocalIndexTag my_inames = get_dependencies(index) & self.knl.all_inames() # find min tag axis @@ -758,7 +755,7 @@ class GlobalSubscriptCounter(CombineMapper): return ToCountMap({MemAccess(mtype='global', dtype=self.type_inf(expr), stride=0, variable=name): 1} - ) + self.rec(expr.index) + ) + self.rec(expr.index) if min_tag_axis != 0: warn_with_kernel(self.knl, "unknown_gmem_stride", @@ -768,7 +765,7 @@ class GlobalSubscriptCounter(CombineMapper): return ToCountMap({MemAccess(mtype='global', dtype=self.type_inf(expr), stride=sys.maxsize, variable=name): 1} - ) + self.rec(expr.index) + ) + self.rec(expr.index) # get local_id associated with minimum tag axis min_lid = None @@ -807,7 +804,7 @@ class GlobalSubscriptCounter(CombineMapper): return ToCountMap({MemAccess(mtype='global', dtype=self.type_inf(expr), stride=total_stride, variable=name): 1} - ) + self.rec(expr.index) + ) + self.rec(expr.index) def map_sum(self, expr): if expr.children: @@ -1203,8 +1200,7 @@ def get_mem_access_map(knl, numpy_types=True): if uniform: from loopy.kernel.data import LocalIndexTag insn_inames = [iname for iname in insn_inames if not - isinstance( - knl.iname_to_tag.get(iname), LocalIndexTag)] + isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except( insn_inames, [dim_type.set])) @@ -1227,7 +1223,7 @@ def get_mem_access_map(knl, numpy_types=True): subs_expr[MemAccess(mtype=key.mtype, dtype=key.dtype, stride=key.stride, direction='load', variable=key.variable) - ] = subs_expr.pop(key) + ] = subs_expr.pop(key) subs_assignee_g = subs_counter_g(insn.assignee) for key in subs_assignee_g.count_map: @@ -1235,7 +1231,7 @@ def get_mem_access_map(knl, numpy_types=True): stride=key.stride, direction='store', variable=key.variable) - ] = subs_assignee_g.pop(key) + ] = subs_assignee_g.pop(key) # for now, don't count writes to local mem insn_inames = knl.insn_inames(insn) @@ -1243,7 +1239,9 @@ def get_mem_access_map(knl, numpy_types=True): # use count excluding local index tags for uniform accesses for key in subs_expr.count_map: map = ToCountMap({key: subs_expr[key]}) - if key.mtype == 'global' and isinstance(key.stride, int) and key.stride == 0: + if (key.mtype == 'global' and + isinstance(key.stride, int) and + key.stride == 0): subs_map = subs_map \ + map*get_insn_count(knl, insn_inames, True) else: @@ -1264,8 +1262,8 @@ def get_mem_access_map(knl, numpy_types=True): dtype=mem_access.dtype.numpy_dtype, stride=mem_access.stride, direction=mem_access.direction, - variable=mem_access.variable) - , count) + variable=mem_access.variable), + count) for mem_access, count in six.iteritems(subs_map.count_map)) return subs_map