From 985ed6da9cf667e302dba0d296e67a4304b0b8f2 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Wed, 24 Jun 2015 10:16:24 -0500
Subject: [PATCH] added tests for DRAM access counter

---
 loopy/statistics.py     |  59 +++++++++++++---
 test/test_statistics.py | 145 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 195 insertions(+), 9 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 2fc5b37fa..be1c3818a 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -218,27 +218,68 @@ class SubscriptCounter(CombineMapper):
     map_call = map_constant
 
     def map_subscript(self, expr):
-        name = expr.aggregate.name
-        arg = self.knl.arg_dict.get(name)
+        """Count one access, keyed by type, for a global-array subscript.
+
+        If the aggregate is not an ``lp.GlobalArg`` kernel argument, the
+        subscript itself is not counted and only the counts found in
+        ``expr.index`` are returned. Otherwise one access of type
+        ``self.type_inf(expr)`` is added to those counts.
+        """
+        name = expr.aggregate.name  # name of array
+
+        if name in self.knl.arg_dict:
+            array = self.knl.arg_dict[name]
+        else:
+            # not a kernel argument: recurse into the index and return
+            return self.rec(expr.index)
+
+        if not isinstance(array, lp.GlobalArg):
+            # not in global memory: recurse into the index and return
+            return self.rec(expr.index)
+
+        index = expr.index  # could be tuple or scalar index
+        if not isinstance(index, tuple):
+            index = (index,)
+
+        from loopy.symbolic import get_dependencies
+        my_inames = get_dependencies(index) & self.knl.all_inames()
+        # TODO when would dependencies not be a subset of all inames?
+
+        for iname in my_inames:
+            # find local id0 through self.knl.iname_to_tag
+            # TODO why are there no tags?
+            pass
+
+        """
+        for axis_index, dim_tag in zip(index, array.dim_tags):
+            # check whether this axis index contains the local id 0 iname
+
+            # determine if stride 1
+
+            # find coefficient
+        """
+
         tv = self.knl.temporary_variables.get(name)
-        if arg is not None:
-            if isinstance(arg, lp.GlobalArg):
+        # placeholder classification: array is always a GlobalArg here
+        if array is not None:
+            if isinstance(array, lp.GlobalArg):
                 # It's global memory
                 pass
         elif tv is not None:
             if tv.is_local:
                 # It's shared memory
                 pass
-        #return 1 + self.rec(expr.index)
+
         return TypeToOpCountMap(
                         {self.type_inf(expr): 1}
                         ) + self.rec(expr.index)
+        # TODO what about duplicate accesses that are sitting in registers?
 
     '''
     def map_subscript(self, expr):
         name = expr.aggregate.name
-        if name in self.kernel.arg_dict:
-            array = self.kernel.arg_dict[name]
+        if name in self.knl.arg_dict:
+            array = self.knl.arg_dict[name]
         else:
             ...
             # recurse and return
@@ -252,10 +293,10 @@ class SubscriptCounter(CombineMapper):
             index = (index,)
 
         from loopy.symbolic import get_dependencies
-        my_inames = get_dependencies(index) & self.kernel.all_inames()
+        my_inames = get_dependencies(index) & self.knl.all_inames()
 
         for iname in my_inames:
-            # find local id0 through self.kernel.index_to_tag
+            # find local id0 through self.knl.iname_to_tag
 
         # If you don't have a local id0
         # -> not stride1 (for now)
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 5f276617d..b6d694e26 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -28,6 +28,7 @@ from pyopencl.tools import (  # noqa
         as pytest_generate_tests)
 import loopy as lp
 from loopy.statistics import get_op_poly  # noqa
+from loopy.statistics import get_DRAM_access_poly  # noqa
 import numpy as np
 
 
@@ -186,6 +187,150 @@ def test_op_counter_triangular_domain():
         assert flops == 78
 
 
+def test_DRAM_access_counter_basic():
+    """Check per-dtype access counts for two plain arithmetic statements."""
+    knl = lp.make_kernel(
+            "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
+                e[i, k] = g[i,k]*h[i,k+1]
+                """
+            ],
+            name="weird", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl,
+                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
+    poly = get_DRAM_access_poly(knl)
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == 3*n*m*l  # a (twice) and b, for each (i, j, k)
+    assert f64 == 2*n*m  # g and h, for each (i, k)
+
+
+def test_DRAM_access_counter_reduction():
+    """Check the access count for a matmul-style sum reduction over k."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                "c[i, j] = sum(k, a[i, k]*b[k, j])"
+            ],
+            name="matmul", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
+    poly = get_DRAM_access_poly(knl)
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == 2*n*m*l  # a and b, for each (i, k, j)
+
+
+def test_DRAM_access_counter_logic():
+    """Check access counts for subscripts inside both branches of an if()."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                e[i,k] = if(not(k<l-2) and k>6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2)
+                """
+            ],
+            name="logic", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
+    poly = get_DRAM_access_poly(knl)
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == 2*n*m  # g appears once in each branch, for each (i, k)
+    assert f64 == n*m  # h appears once, for each (i, k)
+
+
+def test_DRAM_access_counter_specialops():
+    """Check access counts for operands of modulo and power expressions."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
+                e[i, k] = (1+g[i,k])**(1+h[i,k+1])
+                """
+            ],
+            name="specialops", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl,
+                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
+    poly = get_DRAM_access_poly(knl)
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == 2*n*m*l  # a and b, for each (i, j, k)
+    assert f64 == 2*n*m  # g and h, for each (i, k)
+
+
+def test_DRAM_access_counter_bitwise():
+    """Check int32 access counts for operands of bitwise and shift operators."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1)
+                e[i, k] = (g[i,k] ^ k)*(~h[i,k+1]) + (g[i, k] << (h[i,k] >> k))
+                """
+            ],
+            name="bitwise", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(
+            knl, dict(
+                a=np.int32, b=np.int32,
+                g=np.int32, h=np.int32))
+
+    poly = get_DRAM_access_poly(knl)
+    n = 512
+    m = 256
+    l = 128
+    i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert i32 == 4*n*m+2*n*m*l  # g, h twice each per (i, k); a, b per (i, j, k)
+
+'''
+def test_DRAM_access_counter_triangular_domain():
+
+    knl = lp.make_kernel(
+            "{[i,j]: 0<=i<n and 0<=j<m and i<j}",
+            """
+            a[i, j] = b[i,j] * 2
+            """,
+            name="triangle", assumptions="n,m >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl,
+            dict(b=np.float64))
+
+    expect_fallback = False
+    import islpy as isl
+    try:
+        isl.BasicSet.cardz
+    except AttributeError:
+        expect_fallback = True
+    else:
+        expect_fallback = False
+
+    poly = get_DRAM_access_poly(knl)[np.dtype(np.float64)]
+    value_dict = dict(m=13, n=200)
+    subscripts = poly.eval_with_dict(value_dict)
+
+    if expect_fallback:
+        assert subscripts == 144
+    else:
+        assert subscripts == 78  # TODO figure out why this test is broken
+'''
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab