From b64d8af8de9fe309b8fa347fc83e1d06245aab19 Mon Sep 17 00:00:00 2001 From: James Stevens Date: Fri, 18 Nov 2016 12:50:27 -0600 Subject: [PATCH] better handling of case where min_tag_axis != 0 --- loopy/statistics.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index d5e4c43c9..6c9742e52 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -586,10 +586,10 @@ class ExpressionOpCounter(CombineMapper): def map_logical_not(self, expr): return self.rec(expr.child) - def map_logical_or(self, expr): - return sum(self.rec(child) for child in expr.children) + #def map_logical_or(self, expr): + # return sum(self.rec(child) for child in expr.children) - map_logical_and = map_logical_or + #map_logical_and = map_logical_or def map_if(self, expr): warnings.warn("ExpressionOpCounter counting ops as " @@ -796,6 +796,18 @@ class GlobalSubscriptCounter(CombineMapper): variable=name): 1} ) + self.rec(expr.index) + if min_tag_axis != 0: + warn_with_kernel(knl, "unknown_gmem_stride", + "GlobalSubscriptCounter: " + "Memory access minimum tag axis %d != 0, " + "stride unknown, using sys.maxsize." + % (min_tag_axis)) + #TODO switch all warnings to loopy warnings warn_with_kernel + return ToCountMap({MemAccess(mtype='global', + dtype=self.type_inf(expr), + stride=sys.maxsize, variable=name): 1} + ) + self.rec(expr.index) + # get local_id associated with minimum tag axis min_lid = None for iname in my_inames: @@ -807,8 +819,7 @@ class GlobalSubscriptCounter(CombineMapper): # found local_id associated with minimum tag axis - total_stride = None - extra_stride = 1 + total_stride = 0 # check coefficient of min_lid for each axis from loopy.symbolic import CoefficientCollector from loopy.kernel.array import FixedStrideArrayDimTag @@ -830,17 +841,7 @@ class GlobalSubscriptCounter(CombineMapper): else: continue - total_stride = stride*coeff_min_lid*extra_stride - #TODO is there a case where this^ does not execute, - # or executes more than once for two different axes? - - #TODO temporary fix that needs changing: - if min_tag_axis != 0: - print("... min tag axis (%d) is not zero! ..." % (min_tag_axis)) - return ToCountMap({MemAccess(mtype='global', - dtype=self.type_inf(expr), - stride=sys.maxsize, variable=name): 1} - ) + self.rec(expr.index) + total_stride += stride*coeff_min_lid return ToCountMap({MemAccess(mtype='global', dtype=self.type_inf(expr), stride=total_stride, variable=name): 1} -- GitLab