From c69368dc4ea751b49b38e1aacea30cd88a048462 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Wed, 23 Sep 2015 21:13:41 -0500
Subject: [PATCH] experimenting with reg counter

---
 loopy/statistics.py     | 206 ++++++++++++++++++++++++++++++++++++----
 test/test_statistics.py |   8 +-
 2 files changed, 194 insertions(+), 20 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 0799b8d0e..cd7ff8d79 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -396,16 +396,35 @@ class RegisterUsageEstimator(CombineMapper):
         self.knl = knl
         from loopy.expression import TypeInferenceMapper
         self.type_inf = TypeInferenceMapper(knl)
+        self.vars_found = []
+        self.subs_found = []
 
     def combine(self, values):
         return sum(values)
 
+    def forget_prev_vars(self):
+        del self.vars_found[:]
+
+    def forget_prev_subs(self):
+        del self.subs_found[:]
+
     def map_constant(self, expr):
         return 0
 
     def map_tagged_variable(self, expr):
-        return 1
-    map_variable = map_tagged_variable
+        print("tagged_var found, what to do here? ", expr) #TODO
+        return 0
+
+    def map_variable(self, expr):
+        #print("var: ", expr)
+        if expr in self.vars_found:
+            #print("var already in list") #TODO delete all these print statements
+            return 0
+        else:
+            self.vars_found.append(expr)
+            return 1
+
+    #map_variable = map_tagged_variable
     map_call = map_constant  # TODO what is this?
 
     def map_subscript(self, expr):
@@ -415,14 +434,29 @@ class RegisterUsageEstimator(CombineMapper):
             array = self.knl.arg_dict[name]
         else:
             # this is a temporary variable
-            return 1  # TODO +self.rec(expr.index)?
+            #print("subscript, temp var found: ", expr, expr.index, expr.aggregate)
+            #print((expr.index, expr.aggregate) in self.subs_found)
+            return 0
+            # TODO if this is not in shared, count it
+            if (expr.index, expr.aggregate) in self.subs_found:
+                return 0
+            else:
+                self.subs_found.append((expr.index, expr.aggregate))
+                return 1  # TODO +self.rec(expr.index)?
 
         if not isinstance(array, lp.GlobalArg):
             # this array is not in global memory
             return 0  # TODO is this right? recurse on index?
 
         # this is a global mem access
-        return 1  # TODO +self.rec(expr.index)?
+        #print("subscript, global var found: ", expr, expr.index, expr.aggregate)
+        #print((expr.index, expr.aggregate) in self.subs_found)
+        #self.subs_found.append((expr.index, expr.aggregate))
+        if (expr.index, expr.aggregate) in self.subs_found:
+            return 0
+        else:
+            self.subs_found.append((expr.index, expr.aggregate))
+            return 1  # TODO +self.rec(expr.index)?
 
     def map_sum(self, expr):
         if expr.children:
@@ -714,8 +748,12 @@ def get_barrier_poly(knl):
     return barrier_poly
 
 
-'''
 def get_regs_per_thread(knl):
+    return get_regs_per_thread4(knl)
+
+
+# map_var and map_tagged_var returned 1, no checking for any duplication
+def get_regs_per_thread1(knl):
 
     """Estimate registers per thread usage by a loopy kernel.
 
@@ -724,21 +762,48 @@ def get_regs_per_thread(knl):
     """
 
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
-    from loopy.schedule import RunInstruction
+    from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction
+    from operator import mul
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
     knl = lp.get_one_scheduled_kernel(knl)
-    regs = 0
+
+    max_regs = 0
+    current_loop_indices = 0
     reg_counter = RegisterUsageEstimator(knl)
+
+    #TODO test blocks vs lines
     for sched_item in knl.schedule:
-        if isinstance(sched_item, RunInstruction):
+        if isinstance(sched_item, EnterLoop):
+            # need to add indices to index count
+            # if counting by blocks, check current blk total vs max, save if bigger
+            if sched_item.iname:  # (if not empty)
+                current_loop_indices += 1  # TODO assumes all loops add 1 new index
+                #print("enter loop: ", sched_item)
+        elif isinstance(sched_item, LeaveLoop):
+            # need to subtract indices from index count
+            # if counting by blocks, check current blk total vs max, save if bigger
+            if sched_item.iname:  # (if not empty)
+                current_loop_indices -= 1  # TODO assumes all loops add 1 new index
+                #print("leave loop: ", sched_item)
+        elif isinstance(sched_item, RunInstruction):
+            # count regs for this instruction
+            # if counting by blocks, add to current block total
+            # if counting by lines, check current line total vs max, save if bigger
             insn = knl.id_to_insn[sched_item.insn_id]
-            regs += reg_counter(insn.expression) 
-            regs += reg_counter(insn.assignee) 
-    return regs
-'''
+            regs = current_loop_indices + \
+                   reg_counter(insn.assignee) + \
+                   reg_counter(insn.expression)
+            if regs > max_regs:
+                max_regs = regs
+            #print("RunInstruction, regs, max_regs ", sched_item, regs, max_regs)
+            # TODO check for iname reuse
+            # TODO don't count variables if they are loop indices?
 
-def get_regs_per_thread(knl):
+    return max_regs
+
+# no duplicate vars, subs
+def get_regs_per_thread2(knl):
 
     """Estimate registers per thread usage by a loopy kernel.
 
@@ -752,24 +817,27 @@ def get_regs_per_thread(knl):
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
     knl = lp.get_one_scheduled_kernel(knl)
+    #print(knl)
     max_regs = 0
     current_loop_indices = 0
     reg_counter = RegisterUsageEstimator(knl)
 
     #TODO test blocks vs lines
     for sched_item in knl.schedule:
+        reg_counter.forget_prev_vars()
+        reg_counter.forget_prev_subs()
         if isinstance(sched_item, EnterLoop):
             # need to add indices to index count
             # if counting by blocks, check current blk total vs max, save if bigger
             if sched_item.iname:  # (if not empty)
                 current_loop_indices += 1  # TODO assumes all loops add 1 new index
-                print("enter loop: ", sched_item)
+                #print("enter loop: ", sched_item)
         elif isinstance(sched_item, LeaveLoop):
             # need to subtract indices from index count
             # if counting by blocks, check current blk total vs max, save if bigger
             if sched_item.iname:  # (if not empty)
                 current_loop_indices -= 1  # TODO assumes all loops add 1 new index
-                print("leave loop: ", sched_item)
+                #print("leave loop: ", sched_item)
         elif isinstance(sched_item, RunInstruction):
             # count regs for this instruction
             # if counting by blocks, add to current block total
@@ -780,10 +848,116 @@ def get_regs_per_thread(knl):
                    reg_counter(insn.expression)
             if regs > max_regs:
                 max_regs = regs
-            print("RunInstruction, regs, max_regs ", sched_item, regs, max_regs)
+            #print("RunInstruction, regs, max_regs ", sched_item, regs, max_regs)
             # TODO check for iname reuse
             # TODO don't count variables if they are loop indices?
 
     return max_regs
 
+def get_regs_per_thread3(knl):
+
+    """Estimate registers per thread usage by a loopy kernel.
+
+    :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
+
+    """
+
+    from loopy.preprocess import preprocess_kernel, infer_unknown_types
+    from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction
+    from operator import mul
+    knl = infer_unknown_types(knl, expect_completion=True)
+    knl = preprocess_kernel(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+    #print(knl)
+    max_regs = 0
+    current_loop_indices = 0
+    block_reg_totals = [0]
+    reg_counters = [RegisterUsageEstimator(knl)]
+    # multiple counters to track nested sets of previously used iname+index combinations
+
+    for sched_item in knl.schedule:
+        if isinstance(sched_item, EnterLoop):
+            if sched_item.iname:  # (if not empty)
+                #print("entering loop, totals: \n", block_reg_totals, max_regs) 
+                current_loop_indices += 1  # TODO assumes all loops add 1 new index
+                # start a new block total
+                block_reg_totals.append(current_loop_indices)
+                # start a new estimator
+                reg_counters.append(RegisterUsageEstimator(knl))
+                #print("entered loop, totals: \n", block_reg_totals, max_regs) 
+            else:
+                print("Error, how does this happen?")
+                1/0
+
+        elif isinstance(sched_item, LeaveLoop):
+            if sched_item.iname:  # (if not empty)
+                #print("leaving loop, totals: \n", block_reg_totals, max_regs) 
+                current_loop_indices -= 1  # TODO assumes all loops add 1 new index
+                if block_reg_totals[-1] > max_regs:
+                    max_regs = block_reg_totals[-1]
+                # pop to resume previous total
+                #block_reg_totals[-2] += block_reg_totals[-1]
+                block_reg_totals.pop()
+                reg_counters.pop()
+                #print("left loop, totals: \n", block_reg_totals, max_regs) 
+            else:
+                print("Error, how does this happen?")
+                1/0
+        elif isinstance(sched_item, RunInstruction):
+            insn = knl.id_to_insn[sched_item.insn_id]
+            #print("instruction found: ", insn) 
+            #print("pre insn totals: \n", block_reg_totals, max_regs) 
+            block_reg_totals[-1] += reg_counters[-1](insn.assignee) + \
+                                    reg_counters[-1](insn.expression)
+            #print("post insn totals: \n", block_reg_totals, max_regs) 
+            # TODO don't count variables if they are loop indices? (also try this with ctr2)
+
+    #print("finished schedule, totals: \n", block_reg_totals, max_regs)
+    # finished looping, check outer block
+    if block_reg_totals[-1] > max_regs:
+        max_regs = block_reg_totals[-1]
+    #print("final, totals: \n", block_reg_totals, max_regs)
+
+    return max_regs
+
+#add all sub blocks to containing block
+def get_regs_per_thread4(knl):
+
+    """Estimate registers per thread usage by a loopy kernel.
+
+    :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
+
+    """
+
+    from loopy.preprocess import preprocess_kernel, infer_unknown_types
+    from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction
+    from operator import mul
+    knl = infer_unknown_types(knl, expect_completion=True)
+    knl = preprocess_kernel(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+    #print(knl)
+
+    regs = 0
+    max_loop_indices = 0
+    current_loop_indices = 0
+    reg_counter = RegisterUsageEstimator(knl)
+
+    for sched_item in knl.schedule:
+        if isinstance(sched_item, EnterLoop):
+            if sched_item.iname:  # (if not empty)
+                current_loop_indices += 1  # TODO assumes all loops add 1 new index
+                if current_loop_indices > max_loop_indices:
+                    max_loop_indices = current_loop_indices
+                #print("enter loop: ", sched_item)
+        elif isinstance(sched_item, LeaveLoop):
+            # need to subtract indices from index count
+            if sched_item.iname:  # (if not empty)
+                current_loop_indices -= 1  # TODO assumes all loops add 1 new index
+                #print("leave loop: ", sched_item)
+        elif isinstance(sched_item, RunInstruction):
+            # count regs for this instruction
+            insn = knl.id_to_insn[sched_item.insn_id]
+            regs += reg_counter(insn.assignee) + \
+                   reg_counter(insn.expression)
 
+    return regs+max_loop_indices
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 8c45e12ba..985aeb1b9 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -526,7 +526,7 @@ def test_reg_counter_basic():
     knl = lp.add_and_infer_dtypes(knl,
                         dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     regs = get_regs_per_thread(knl)
-    assert regs == 7
+    assert regs == 9
 
 
 def test_reg_counter_reduction():
@@ -556,7 +556,7 @@ def test_reg_counter_logic():
 
     knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
     regs = get_regs_per_thread(knl)
-    assert regs == 11
+    assert regs == 7
 
 
 def test_reg_counter_specialops():
@@ -574,7 +574,7 @@ def test_reg_counter_specialops():
     knl = lp.add_and_infer_dtypes(knl,
                         dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     regs = get_regs_per_thread(knl)
-    assert regs == 6
+    assert regs == 9
 
 
 def test_reg_counter_bitwise():
@@ -594,7 +594,7 @@ def test_reg_counter_bitwise():
                 a=np.int32, b=np.int32,
                 g=np.int64, h=np.int64))
     regs = get_regs_per_thread(knl)
-    assert regs == 9
+    assert regs == 11
 
 
 def test_all_counters_parallel_matmul():
-- 
GitLab