From 4870c19050a0ba29a8432b569d397baa7317b944 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Fri, 16 Oct 2015 20:02:59 -0500
Subject: [PATCH] reg usage estimator in somewhat usable state

---
 loopy/statistics.py     | 149 +++++++++++++++++++---------------------
 test/test_statistics.py |   5 +-
 2 files changed, 72 insertions(+), 82 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 48fda3e04..d81050c65 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -550,23 +550,27 @@ class RegisterUsageEstimator(CombineMapper):
         return 0
     #'''
     def map_variable(self, expr):
+        name = expr.name
         if expr in self.vars_found:
             return 0
-        else:
-            self.vars_found.append(expr)
-            print("new var found: ", expr)
-            print("knl.temp_vars: \n", self.knl.temporary_variables)
-            print("found in temp_vars? ", expr.name in self.knl.temporary_variables)
-            print("found in inames? ", expr.name in self.knl.all_inames)
-            #print("knl.vars: \n", self.knl.variables)
-            if expr.name in self.knl.temporary_variables:
-                print("local? ", self.knl.temporary_variables[expr.name].is_local)
-                
-            #print("local? ", self.knl.temporary_variables[expr.name].is_local)
-            if "_dim_" in str(expr): #TODO how to remove block/thread size/id vars?
+
+        self.vars_found.append(expr)
+        if name in self.knl.temporary_variables:
+            if self.knl.temporary_variables[name].is_local:
+                print("found temp var with local tag, not counting: ", expr) #TODO remove after debug
                 return 0
             else:
                 return 1
+        elif name in self.knl.all_inames():
+            from loopy.kernel.data import AxisTag
+            if (self.knl.iname_to_tag.get(name) is None or
+                    not isinstance(self.knl.iname_to_tag.get(name), AxisTag)):
+                #TODO use more specific positive instead of negative
+                return 1
+            else:
+                return 0
+        else:
+            return 1
 
     map_tagged_variable = map_variable
 
@@ -917,6 +921,57 @@ def get_regs_per_thread(knl):
     return get_regs_per_thread3_2(knl)
 
 
+def get_regs_per_thread3_2(knl):
+
+    """Estimate registers per thread usage by a loopy kernel.
+
+    :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
+
+    """
+
+    from loopy.preprocess import preprocess_kernel, infer_unknown_types
+    from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction
+    from operator import mul
+    knl = infer_unknown_types(knl, expect_completion=True)
+    knl = preprocess_kernel(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+    max_regs = 0
+    block_reg_totals = [0]
+    reg_counters = [RegisterUsageEstimator(knl)]
+    # multiple counters to track nested sets of previously used iname+index combinations
+
+    for sched_item in knl.schedule:
+        if isinstance(sched_item, EnterLoop):
+            if sched_item.iname:  # (if not empty)
+                block_reg_totals.append(0)
+                # start a new estimator
+                reg_counters.append(RegisterUsageEstimator(knl))
+            else:
+                print("Error, how does this happen?") #TODO
+                1/0
+
+        elif isinstance(sched_item, LeaveLoop):
+            if sched_item.iname:  # (if not empty)
+                if block_reg_totals[-1] > max_regs:
+                    max_regs = block_reg_totals[-1]
+                # pop to resume previous total
+                block_reg_totals.pop()
+                reg_counters.pop()
+            else:
+                print("Error, how does this happen?") #TODO
+                1/0
+        elif isinstance(sched_item, RunInstruction):
+            insn = knl.id_to_insn[sched_item.insn_id]
+            block_reg_totals[-1] += reg_counters[-1](insn.assignee) + \
+                                    reg_counters[-1](insn.expression)
+
+    # finished looping, check outer block
+    if block_reg_totals[-1] > max_regs:
+        max_regs = block_reg_totals[-1]
+
+    return max_regs
+
+'''
 # map_var and map_tagged_var returned 1, no checking for any duplication
 def get_regs_per_thread1(knl):
 
@@ -1084,74 +1139,9 @@ def get_regs_per_thread3(knl):
     #print("final, totals: \n", block_reg_totals, max_regs)
 
     return max_regs
+'''
 
-def get_regs_per_thread3_2(knl):
-
-    """Estimate registers per thread usage by a loopy kernel.
-
-    :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
-
-    """
-
-    from loopy.preprocess import preprocess_kernel, infer_unknown_types
-    from loopy.schedule import EnterLoop, LeaveLoop, Barrier, RunInstruction
-    from operator import mul
-    knl = infer_unknown_types(knl, expect_completion=True)
-    knl = preprocess_kernel(knl)
-    knl = lp.get_one_scheduled_kernel(knl)
-    #print(knl)
-    max_regs = 0
-    #current_loop_indices = 0
-    block_reg_totals = [0]
-    reg_counters = [RegisterUsageEstimator(knl)]
-    # multiple counters to track nested sets of previously used iname+index combinations
-
-    for sched_item in knl.schedule:
-        if isinstance(sched_item, EnterLoop):
-            if sched_item.iname:  # (if not empty)
-                #print("entering loop, totals: \n", block_reg_totals, max_regs) 
-                #current_loop_indices += 1  # TODO assumes all loops add 1 new index
-                # start a new block total
-                #block_reg_totals.append(current_loop_indices)
-                block_reg_totals.append(0)
-                # start a new estimator
-                reg_counters.append(RegisterUsageEstimator(knl))
-                #print("entered loop, totals: \n", block_reg_totals, max_regs) 
-            else:
-                print("Error, how does this happen?")
-                1/0
-
-        elif isinstance(sched_item, LeaveLoop):
-            if sched_item.iname:  # (if not empty)
-                #print("leaving loop, totals: \n", block_reg_totals, max_regs) 
-                #current_loop_indices -= 1  # TODO assumes all loops add 1 new index
-                if block_reg_totals[-1] > max_regs:
-                    max_regs = block_reg_totals[-1]
-                # pop to resume previous total
-                #block_reg_totals[-2] += block_reg_totals[-1]
-                block_reg_totals.pop()
-                reg_counters.pop()
-                #print("left loop, totals: \n", block_reg_totals, max_regs) 
-            else:
-                print("Error, how does this happen?")
-                1/0
-        elif isinstance(sched_item, RunInstruction):
-            insn = knl.id_to_insn[sched_item.insn_id]
-            #print("instruction found: ", insn) 
-            #print("pre insn totals: \n", block_reg_totals, max_regs) 
-            block_reg_totals[-1] += reg_counters[-1](insn.assignee) + \
-                                    reg_counters[-1](insn.expression)
-            #print("post insn totals: \n", block_reg_totals, max_regs) 
-            # TODO don't count variables if they are loop indices? (also try this with ctr2)
-
-    #print("finished schedule, totals: \n", block_reg_totals, max_regs)
-    # finished looping, check outer block
-    if block_reg_totals[-1] > max_regs:
-        max_regs = block_reg_totals[-1]
-    #print("final, totals: \n", block_reg_totals, max_regs)
-
-    return max_regs
-
+'''
 #add all sub blocks to containing block
 #aka add everything together
 def get_regs_per_thread4(knl):
@@ -1194,3 +1184,4 @@ def get_regs_per_thread4(knl):
                    reg_counter(insn.expression)
 
     return regs+max_loop_indices
+'''
diff --git a/test/test_statistics.py b/test/test_statistics.py
index fedf119dd..3ae1139e6 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -671,7 +671,7 @@ def test_barrier_counter_barriers():
     barrier_count = poly.eval_with_dict(params)
     assert barrier_count == 50*10*2
 
-'''
+
 def test_reg_counter_basic():
 
     knl = lp.make_kernel(
@@ -687,7 +687,6 @@ def test_reg_counter_basic():
     knl = lp.add_and_infer_dtypes(knl,
                         dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     regs = get_regs_per_thread(knl)
-    1/0
     assert regs == 6
 
 
@@ -757,7 +756,7 @@ def test_reg_counter_bitwise():
                 g=np.int64, h=np.int64))
     regs = get_regs_per_thread(knl)
     assert regs == 6
-'''
+
 
 def test_all_counters_parallel_matmul():
 
-- 
GitLab