From e932d1a2bf15807b38972ea17efe52a3b4da7f6b Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Wed, 15 Jul 2015 15:16:15 -0500
Subject: [PATCH] Count operations and subscripts in instruction assignees

---
 loopy/statistics.py     |  4 ++--
 test/test_statistics.py | 34 +++++++++++++++++-----------------
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 2e061ec6d..b4a0b20b1 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -413,7 +413,7 @@ def get_op_poly(knl):
         insn_inames = knl.insn_inames(insn)
         inames_domain = knl.get_inames_domain(insn_inames)
         domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
-        ops = op_counter(insn.expression)
+        ops = op_counter(insn.expression) + op_counter(insn.assignee)
         op_poly = op_poly + ops*count(knl, domain)
     return op_poly
 
@@ -429,7 +429,7 @@ def get_DRAM_access_poly(knl):  # for now just counting subscripts
         insn_inames = knl.insn_inames(insn)
         inames_domain = knl.get_inames_domain(insn_inames)
         domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
-        subs = subscript_counter(insn.expression)
+        subs = subscript_counter(insn.expression) + subscript_counter(insn.assignee)
         subs_poly = subs_poly + subs*count(knl, domain)
     return subs_poly
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index dc040864f..a77c75cf8 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -38,7 +38,7 @@ def test_op_counter_basic():
             [
                 """
                 c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
-                e[i, k] = g[i,k]*h[i,k+1]
+                e[i, k+1] = g[i,k]*h[i,k+1]
                 """
             ],
             name="basic", assumptions="n,m,l >= 1")
@@ -54,7 +54,7 @@ def test_op_counter_basic():
     i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 3*n*m*l
     assert f64 == n*m
-    assert i32 == n*m
+    assert i32 == n*m*2
 
 
 def test_op_counter_reduction():
@@ -209,8 +209,8 @@ def test_DRAM_access_counter_basic():
     f64 = poly.dict[
                     (np.dtype(np.float64), 'uniform')
                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f32 == 3*n*m*l
-    assert f64 == 2*n*m
+    assert f32 == 4*n*m*l
+    assert f64 == 3*n*m
 
 
 def test_DRAM_access_counter_reduction():
@@ -230,7 +230,7 @@ def test_DRAM_access_counter_reduction():
     f32 = poly.dict[
                     (np.dtype(np.float32), 'uniform')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f32 == 2*n*m*l
+    assert f32 == 2*n*m*l+n*l
 
 
 def test_DRAM_access_counter_logic():
@@ -256,7 +256,7 @@ def test_DRAM_access_counter_logic():
                     (np.dtype(np.float64), 'uniform')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 2*n*m
-    assert f64 == n*m
+    assert f64 == 2*n*m
 
 
 def test_DRAM_access_counter_specialops():
@@ -283,8 +283,8 @@ def test_DRAM_access_counter_specialops():
     f64 = poly.dict[
                     (np.dtype(np.float64), 'uniform')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f32 == 2*n*m*l
-    assert f64 == 2*n*m
+    assert f32 == 3*n*m*l
+    assert f64 == 3*n*m
 
 
 def test_DRAM_access_counter_bitwise():
@@ -311,7 +311,7 @@ def test_DRAM_access_counter_bitwise():
     i32 = poly.dict[
                     (np.dtype(np.int32), 'uniform')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert i32 == 4*n*m+2*n*m*l
+    assert i32 == 5*n*m+3*n*m*l
 
 
 def test_DRAM_access_counter_mixed():
@@ -340,8 +340,8 @@ def test_DRAM_access_counter_mixed():
     f32nonconsec = poly.dict[
                     (np.dtype(np.float32), 'nonconsecutive')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f64uniform == 2*n*m
-    assert f32nonconsec == 3*n*m*l
+    assert f64uniform == 3*n*m
+    assert f32nonconsec == 4*n*m*l
 
 
 def test_DRAM_access_counter_nonconsec():
@@ -370,8 +370,8 @@ def test_DRAM_access_counter_nonconsec():
     f32nonconsec = poly.dict[
                     (np.dtype(np.float32), 'nonconsecutive')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f64nonconsec == 2*n*m
-    assert f32nonconsec == 3*n*m*l
+    assert f64nonconsec == 3*n*m
+    assert f32nonconsec == 4*n*m*l
 
 
 def test_DRAM_access_counter_consec():
@@ -400,8 +400,8 @@ def test_DRAM_access_counter_consec():
     f32consec = poly.dict[
                     (np.dtype(np.float32), 'consecutive')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f64consec == 2*n*m
-    assert f32consec == 3*n*m*l
+    assert f64consec == 3*n*m
+    assert f32consec == 4*n*m*l
 
 
 def test_barrier_counter_nobarriers():
@@ -487,9 +487,9 @@ def test_all_counters_parallel_matmul():
 
     assert barrier_count == 0
     assert f32ops == n*m*l*2
-    assert i32ops == n*m*l*4
+    assert i32ops == n*m*l*4+l*n*4
     assert f32uncoal == n*m*l
-    assert f32coal == n*m*l
+    assert f32coal == n*m*l+n*l
 
 
 if __name__ == "__main__":
-- 
GitLab