From 7d33b5c7d6d920ffee95429b7c9085b0afab8497 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Wed, 30 Sep 2015 21:34:27 -0500
Subject: [PATCH] experimenting with reg usage estimator

---
 loopy/statistics.py     | 118 +++++++++++++++++++++++++++++++---------
 test/test_statistics.py |  10 ++--
 2 files changed, 98 insertions(+), 30 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index cd7ff8d79..7b1a5c7da 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -410,19 +410,18 @@ class RegisterUsageEstimator(CombineMapper):
 
     def map_constant(self, expr):
         return 0
-
-    def map_tagged_variable(self, expr):
-        print("tagged_var found, what to do here? ", expr) #TODO
-        return 0
-
+    # Each variable counts once; repeats and names containing "_dim_" count 0.
     def map_variable(self, expr):
-        #print("var: ", expr)
         if expr in self.vars_found:
-            #print("var already in list") #TODO delete all these print statements
             return 0
         else:
             self.vars_found.append(expr)
-            return 1
+            if "_dim_" in str(expr):  # TODO: better way to skip block/thread size/id vars?
+                return 0
+            else:
+                return 1
+
+    map_tagged_variable = map_variable
 
     #map_variable = map_tagged_variable
     map_call = map_constant  # TODO what is this?
@@ -431,32 +430,33 @@ class RegisterUsageEstimator(CombineMapper):
         name = expr.aggregate.name  # name of array
 
         if name in self.knl.arg_dict:
+            # not a temporary variable
             array = self.knl.arg_dict[name]
+        elif self.knl.temporary_variables[name].is_local:
+            # temp var is in shared mem
+            return 0 + self.rec(expr.index)
+        elif (expr.index, expr.aggregate) in self.subs_found:
+            # temp var is NOT shared, but already counted
+            return 0 + self.rec(expr.index)
         else:
-            # this is a temporary variable
-            #print("subscript, temp var found: ", expr, expr.index, expr.aggregate)
-            #print((expr.index, expr.aggregate) in self.subs_found)
-            return 0
-            # TODO if this is not in shared, count it
-            if (expr.index, expr.aggregate) in self.subs_found:
-                return 0
-            else:
-                self.subs_found.append((expr.index, expr.aggregate))
-                return 1  # TODO +self.rec(expr.index)?
+            # temp var is NOT shared and NOT already counted
+            self.subs_found.append((expr.index, expr.aggregate))
+            return 1 + self.rec(expr.index)
+
+        # expr is not a temporary variable
 
         if not isinstance(array, lp.GlobalArg):
+            print("debug... When does this happen? ", expr, array)
+            1/0
             # this array is not in global memory
-            return 0  # TODO is this right? recurse on index?
+            return 1 + self.rec(expr.index)  # TODO
 
         # this is a global mem access
-        #print("subscript, global var found: ", expr, expr.index, expr.aggregate)
-        #print((expr.index, expr.aggregate) in self.subs_found)
-        #self.subs_found.append((expr.index, expr.aggregate))
         if (expr.index, expr.aggregate) in self.subs_found:
-            return 0
+            return 0 + self.rec(expr.index)
         else:
             self.subs_found.append((expr.index, expr.aggregate))
-            return 1  # TODO +self.rec(expr.index)?
+            return 1 + self.rec(expr.index)
 
     def map_sum(self, expr):
         if expr.children:
@@ -749,7 +749,7 @@ def get_barrier_poly(knl):
 
 
 def get_regs_per_thread(knl):
-    return get_regs_per_thread4(knl)
+    return get_regs_per_thread3_2(knl)  # estimator variant currently in use
 
 
 # map_var and map_tagged_var returned 1, no checking for any duplication
@@ -920,7 +920,75 @@ def get_regs_per_thread3(knl):
 
     return max_regs
 
+def get_regs_per_thread3_2(knl):
+
+    """Estimate registers per thread usage by a loopy kernel.
+
+    The kernel is type-inferred, preprocessed and scheduled, and the schedule
+    is then walked once. Every loop nesting level (plus the code outside all
+    loops) has its own running total and its own
+    :class:`RegisterUsageEstimator`, so an item is only recognized as already
+    counted by instructions that sit directly in the same loop body. The
+    result is the largest total reached by any single level; the total of an
+    inner loop is not added to the loop containing it.
+
+    :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
+
+    :returns: The largest per-level total found while walking the schedule.
+
+    :raises RuntimeError: If the schedule enters or leaves a loop that has no
+        iname.
+
+    """
+
+    from loopy.preprocess import preprocess_kernel, infer_unknown_types
+    from loopy.schedule import EnterLoop, LeaveLoop, RunInstruction
+
+    knl = infer_unknown_types(knl, expect_completion=True)
+    knl = preprocess_kernel(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+
+    max_regs = 0
+
+    # Two parallel stacks with one entry per open nesting level; entry 0 is
+    # the code outside all loops. block_reg_totals holds the running total of
+    # each level, reg_counters the estimator that remembers what that level
+    # has already counted.
+    block_reg_totals = [0]
+    reg_counters = [RegisterUsageEstimator(knl)]
+
+    for sched_item in knl.schedule:
+        is_loop_boundary = isinstance(sched_item, (EnterLoop, LeaveLoop))
+        if is_loop_boundary and not sched_item.iname:
+            # not expected to happen; fail loudly instead of miscounting
+            raise RuntimeError(
+                "schedule item %r enters or leaves a loop without an iname"
+                % (sched_item,))
+
+        if isinstance(sched_item, EnterLoop):
+            # start a new block total and a new estimator
+            block_reg_totals.append(0)
+            reg_counters.append(RegisterUsageEstimator(knl))
+
+        elif isinstance(sched_item, LeaveLoop):
+            # record this block's total, then pop to resume the previous one
+            max_regs = max(max_regs, block_reg_totals.pop())
+            reg_counters.pop()
+
+        elif isinstance(sched_item, RunInstruction):
+            insn = knl.id_to_insn[sched_item.insn_id]
+            counter = reg_counters[-1]
+            block_reg_totals[-1] += (
+                counter(insn.assignee) + counter(insn.expression))
+            # TODO don't count variables if they are loop indices?
+            # (also try this with ctr2)
+
+    # finished looping, check outer block
+    max_regs = max(max_regs, block_reg_totals[-1])
+    return max_regs
+
 #add all sub blocks to containing block
+#aka add everything together
 def get_regs_per_thread4(knl):
 
     """Estimate registers per thread usage by a loopy kernel.
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 985aeb1b9..7d89a521d 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -526,7 +526,7 @@ def test_reg_counter_basic():
     knl = lp.add_and_infer_dtypes(knl,
                         dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     regs = get_regs_per_thread(knl)
-    assert regs == 9
+    assert regs == 6
 
 
 def test_reg_counter_reduction():
@@ -540,7 +540,7 @@ def test_reg_counter_reduction():
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
     regs = get_regs_per_thread(knl)
-    assert regs == 7
+    assert regs == 6
 
 
 def test_reg_counter_logic():
@@ -556,7 +556,7 @@ def test_reg_counter_logic():
 
     knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
     regs = get_regs_per_thread(knl)
-    assert regs == 7
+    assert regs == 6
 
 
 def test_reg_counter_specialops():
@@ -574,7 +574,7 @@ def test_reg_counter_specialops():
     knl = lp.add_and_infer_dtypes(knl,
                         dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     regs = get_regs_per_thread(knl)
-    assert regs == 9
+    assert regs == 6
 
 
 def test_reg_counter_bitwise():
@@ -594,7 +594,7 @@ def test_reg_counter_bitwise():
                 a=np.int32, b=np.int32,
                 g=np.int64, h=np.int64))
     regs = get_regs_per_thread(knl)
-    assert regs == 11
+    assert regs == 6
 
 
 def test_all_counters_parallel_matmul():
-- 
GitLab