From e16469d799a21f7cc5459a445c64bc05a21209cc Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@porter.cs.illinois.edu>
Date: Fri, 1 Apr 2016 17:35:36 -0500
Subject: [PATCH] adding local mem access counter

---
 loopy/statistics.py     | 188 ++++++++++++++++++++++++++++++++++++++++
 test/test_statistics.py |  35 ++++----
 2 files changed, 204 insertions(+), 19 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 3f2c3a4b5..419eb2868 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -120,6 +120,24 @@ class TypedOp:
         return hash(str(self.dtype)+self.name)
 
 
+class LmemAccess:
+
+    def __init__(self, dtype, direction=None):
+        self.dtype = dtype
+        self.direction = direction
+
+    def __eq__(self, other):
+        return isinstance(other, LmemAccess) and (
+                other.dtype == self.dtype and
+                other.direction == self.direction)
+
+    def __hash__(self):
+        direction = self.direction
+        if direction == None:
+            direction = 'None'
+        return hash(str(self.dtype)+direction)
+
+
 class StridedGmemAccess:
 
     #TODO "ANY_VAR" does not work yet
@@ -279,6 +297,116 @@ class ExpressionOpCounter(CombineMapper):
 # }}}
 
 
+# {{{ LocalSubscriptCounter
+
+class LocalSubscriptCounter(CombineMapper):
+
+    def __init__(self, knl):
+        self.knl = knl
+        from loopy.expression import TypeInferenceMapper
+        self.type_inf = TypeInferenceMapper(knl)
+
+    def combine(self, values):
+        return sum(values)
+
+    def map_constant(self, expr):
+        return ToCountMap()
+
+    map_tagged_variable = map_constant
+    map_variable = map_constant
+
+    def map_call(self, expr):
+        return self.rec(expr.parameters)
+
+    def map_subscript(self, expr):
+        name = expr.aggregate.name  # name of array
+
+        if name in self.knl.temporary_variables:
+            array = self.knl.temporary_variables[name]
+            #print("array: ", array)
+            #print("is local? ", array.is_local)
+            if array.is_local:
+                return ToCountMap(
+                        {LmemAccess(self.type_inf(expr), direction=None): 1}
+                        ) + self.rec(expr.index)
+
+        return self.rec(expr.index)
+            
+    def map_sum(self, expr):
+        if expr.children:
+            return sum(self.rec(child) for child in expr.children)
+        else:
+            return ToCountMap()
+
+    map_product = map_sum
+
+    def map_quotient(self, expr, *args):
+        return self.rec(expr.numerator) + self.rec(expr.denominator)
+
+    map_floor_div = map_quotient
+    map_remainder = map_quotient
+
+    def map_power(self, expr):
+        return self.rec(expr.base) + self.rec(expr.exponent)
+
+    def map_left_shift(self, expr):
+        return self.rec(expr.shiftee)+self.rec(expr.shift)
+
+    map_right_shift = map_left_shift
+
+    def map_bitwise_not(self, expr):
+        return self.rec(expr.child)
+
+    def map_bitwise_or(self, expr):
+        return sum(self.rec(child) for child in expr.children)
+
+    map_bitwise_xor = map_bitwise_or
+    map_bitwise_and = map_bitwise_or
+
+    def map_comparison(self, expr):
+        return self.rec(expr.left)+self.rec(expr.right)
+
+    map_logical_not = map_bitwise_not
+    map_logical_or = map_bitwise_or
+    map_logical_and = map_logical_or
+
+    def map_if(self, expr):
+        warnings.warn("LocalSubscriptCounter counting LMEM accesses as "
+                      "sum of if-statement branches.")
+        return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_)
+
+    def map_if_positive(self, expr):
+        warnings.warn("LocalSubscriptCounter counting LMEM accesses as "
+                      "sum of if_pos-statement branches.")
+        return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_)
+
+    map_min = map_bitwise_or
+    map_max = map_min
+
+    def map_common_subexpression(self, expr):
+        raise NotImplementedError("LocalSubscriptCounter encountered "
+                                  "common_subexpression, "
+                                  "map_common_subexpression not implemented.")
+
+    def map_substitution(self, expr):
+        raise NotImplementedError("LocalSubscriptCounter encountered "
+                                  "substitution, "
+                                  "map_substitution not implemented.")
+
+    def map_derivative(self, expr):
+        raise NotImplementedError("LocalSubscriptCounter encountered "
+                                  "derivative, "
+                                  "map_derivative not implemented.")
+
+    def map_slice(self, expr):
+        raise NotImplementedError("LocalSubscriptCounter encountered slice, "
+                                  "map_slice not implemented.")
+
+# }}}
+
+
+
+
 # {{{ GlobalSubscriptCounter
 
 class GlobalSubscriptCounter(CombineMapper):
@@ -674,6 +802,66 @@ def sum_ops_to_dtypes(op_poly_dict):
     return result
 
 
+def get_lmem_access_poly(knl):
+
+    """Count the number of local memory accesses in a loopy kernel.
+    """
+
+    from loopy.preprocess import preprocess_kernel, infer_unknown_types
+
+    class CacheHolder(object):
+        pass
+
+    cache_holder = CacheHolder()
+
+    @memoize_in(cache_holder, "insn_count")
+    def get_insn_count(knl, insn_inames):
+        inames_domain = knl.get_inames_domain(insn_inames)
+        domain = (inames_domain.project_out_except(
+                                insn_inames, [dim_type.set]))
+        return count(knl, domain)
+
+    knl = infer_unknown_types(knl, expect_completion=True)
+    knl = preprocess_kernel(knl)
+
+    subs_poly = ToCountMap()
+    subscript_counter = LocalSubscriptCounter(knl)
+    for insn in knl.instructions:
+        # count subscripts, distinguishing loads and stores
+        subs_expr = subscript_counter(insn.expression)
+        for key in subs_expr.dict:
+            subs_expr.dict[LmemAccess(
+                           key.dtype, direction='load')
+                          ] = subs_expr.dict.pop(key)
+        subs_assignee = subscript_counter(insn.assignee)
+        for key in subs_assignee.dict:
+            print(key.dtype, key.direction, subs_assignee.dict[key])
+
+        # for now, not counting stores in local mem
+        '''
+        for key in subs_assignee.dict:
+            subs_assignee.dict[LmemAccess(
+                               key.dtype, direction='store')
+                              ] = subs_assignee.dict.pop(key)
+        '''
+
+        insn_inames = knl.insn_inames(insn)
+
+        # use count excluding local index tags for uniform accesses
+        for key in subs_expr.dict:
+            poly = ToCountMap({key: subs_expr.dict[key]})
+            subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)
+
+        # for now, not counting stores in local mem
+        '''
+        for key in subs_assignee.dict:
+            poly = ToCountMap({key: subs_assignee.dict[key]})
+            subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)
+        '''
+
+    return subs_poly.dict
+
+
 # {{{ get_gmem_access_poly
 def get_gmem_access_poly(knl):  # for now just counting subscripts
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 56b9f4003..0353ac0d4 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -31,8 +31,10 @@ import loopy as lp
 from loopy.statistics import (
         get_op_poly,
         get_gmem_access_poly,
+        get_lmem_access_poly,
         get_barrier_poly,
         StridedGmemAccess,
+        LmemAccess,
         TypedOp)
 
 import numpy as np
@@ -578,6 +580,9 @@ def test_all_counters_parallel_matmul():
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
     knl = lp.split_iname(knl, "i", 16, outer_tag="g.0", inner_tag="l.1")
     knl = lp.split_iname(knl, "j", 16, outer_tag="g.1", inner_tag="l.0")
+    knl = lp.split_iname(knl, "k", 16)
+    knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"])
+    knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"])
 
     n = 512
     m = 256
@@ -585,7 +590,7 @@ def test_all_counters_parallel_matmul():
     params = {'n': n, 'm': m, 'l': l}
 
     barrier_count = get_barrier_poly(knl).eval_with_dict(params)
-    assert barrier_count == 0
+    assert barrier_count == 2*m/16
 
     op_map = get_op_poly(knl)
     f32mul = op_map[
@@ -602,35 +607,27 @@ def test_all_counters_parallel_matmul():
                         ].eval_with_dict(params)
 
     assert f32mul+f32add == n*m*l*2
-    assert i32ops == n*m*l*4 + l*n*4
 
     subscript_map = get_gmem_access_poly(knl)
-    #f32uncoal = subscript_map[StridedGmemAccess(
-    #                          np.dtype(np.float32), Variable('m'), direction='load', variable='ANY_VAR')
-    #                          ].eval_with_dict(params)
-    #test = StridedGmemAccess(np.dtype(np.float32), sys.maxsize, direction='load', variable='ANY_VAR')
-    #print("test key: ", test.dtype, test.stride, test.direction, test.variable)
-    #for key in subscript_map:
-    #    print(key.dtype, key.stride, key.direction, key.variable)
-    f32uncoal = subscript_map[StridedGmemAccess(
-                              np.dtype(np.float32), sys.maxsize, direction='load', variable='a')
-                              ].eval_with_dict(params)
-    '''
-    f32uncoal = subscript_map[StridedGmemAccess(
-                              np.dtype(np.float32), sys.maxsize, direction='load', variable='ANY_VAR')
-                              ].eval_with_dict(params)
-    '''
+
     f32coal = subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, direction='load', variable='b')
                             ].eval_with_dict(params)
+    f32coal += subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, direction='load', variable='a')
+                            ].eval_with_dict(params)
 
-    assert f32uncoal == n*m*l
-    assert f32coal == n*m*l
+    assert f32coal == n*m+m*l
 
     f32coal = subscript_map[StridedGmemAccess(np.dtype(np.float32), 1, direction='store', variable='c')
                             ].eval_with_dict(params)
 
     assert f32coal == n*l
 
+    local_subs_map = get_lmem_access_poly(knl)
+
+    local_subs_l = local_subs_map[LmemAccess(np.dtype(np.float32), direction='load')
+                                  ].eval_with_dict(params)
+
+    assert local_subs_l == n*m*l*2
 
 def test_gather_access_footprint():
     knl = lp.make_kernel(
-- 
GitLab