From 6ff4c160140018a2d9b49bea37ba5d76e59cc9bf Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Sun, 12 Jul 2015 15:08:54 -0500
Subject: [PATCH] cleaned up barrier counter

---
 loopy/statistics.py     | 46 ++++++++++++++++++-----------------------
 test/test_statistics.py | 27 +++++++++++++++++++++++-
 2 files changed, 46 insertions(+), 27 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 8e1210c11..f5002fba0 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -447,41 +447,35 @@ def get_DRAM_access_poly(knl):  # for now just counting subscripts
 
 def get_barrier_poly(knl):
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
+    from loopy.schedule import EnterLoop, LeaveLoop, Barrier
+    from operator import mul
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
-
     knl = lp.get_one_scheduled_kernel(knl)
-    loop_iters = [1]  # [isl.PwQPolynomial('[]->{ 1 }')]
-    barrier_poly = 0  # isl.PwQPolynomial('[]->{ 0 }')
-    from loopy.schedule import EnterLoop, LeaveLoop, Barrier
-    from operator import mul
-    print("TESTING... kernel sched: \n", knl.schedule)
+    iname_list = []
+    barrier_poly = isl.PwQPolynomial('{ 0 }')  # 0
+
     for sched_item in knl.schedule:
-        print("TESTING... sched_item: ", sched_item)
         if isinstance(sched_item, EnterLoop):
-            print("TESTING... iname: ", sched_item.iname)
-            ct = count(knl, (
-                            knl.get_inames_domain(sched_item.iname).
-                            project_out_except(sched_item.iname, [dim_type.set])
-                            ))
-            if ct is not None:
-                loop_iters.append(ct)
+            if sched_item.iname:  # (if not empty)
+                iname_list.append(sched_item.iname)
         elif isinstance(sched_item, LeaveLoop):
-            print("TESTING... iname: ", sched_item.iname)
-            ct = count(knl, (
-                            knl.get_inames_domain(sched_item.iname).
-                            project_out_except(sched_item.iname, [dim_type.set])
-                            ))
-            if ct is not None:
-                loop_iters.pop()
+            if sched_item.iname:  # (if not empty)
+                iname_list.pop()
         elif isinstance(sched_item, Barrier):
-            print("TESTING... I FOUND A BARRIER!!!")
-            barrier_poly += reduce(mul, loop_iters)
-        print("TESTING... current iter list: \n", loop_iters)
-        print("TESTING... current iter product: \n", reduce(mul, loop_iters))
+            if iname_list:  # (if iname_list is not empty)
+                ct = (count(knl, (
+                                knl.get_inames_domain(iname_list).
+                                project_out_except(iname_list, [dim_type.set])
+                                )), )
+                barrier_poly += reduce(mul, ct)
+            else:
+                barrier_poly += isl.PwQPolynomial('{ 1 }')
+    '''
     if not isinstance(barrier_poly, isl.PwQPolynomial):
-        # TODO figure out how to fix this
+        #TODO figure out better fix for this
         string = "{"+str(barrier_poly)+"}"
         return isl.PwQPolynomial(string)
+    '''
     return barrier_poly
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 189a12bc8..7a0494ce4 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -424,7 +424,32 @@ def test_barrier_counter_basic():
     l = 128
     barrier_count = poly.eval_with_dict({'n': n, 'm': m, 'l': l})
     assert barrier_count == 0
-    # TODO test kernels with barriers
+
+
+def test_barrier_counter():
+
+    knl = lp.make_kernel(
+            "[n,m,l] -> {[i,k,j]: 0<=i<50 and 1<=k<98 and 0<=j<10}",
+            [
+                """
+            c[i,j,k] = 2*a[i,j,k] {id=first}
+            e[i,j,k] = c[i,j,k+1]+c[i,j,k-1] {dep=first}
+            """
+            ], [
+                lp.TemporaryVariable("c", lp.auto, shape=(50, 10, 99)),
+                "..."
+            ],
+            name="weird2",
+            )
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.int32))
+    knl = lp.split_iname(knl, "k", 128, outer_tag="g.0", inner_tag="l.0")
+    poly = get_barrier_poly(knl)
+    n = 512
+    m = 256
+    l = 128
+    barrier_count = poly.eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert barrier_count == 1000
+    # TODO more barrier counting tests
 
 
 if __name__ == "__main__":
-- 
GitLab