From 9a19d4f8059e9676b37bc99aad1c2fa192a69b75 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Wed, 7 Dec 2016 22:52:06 -0600
Subject: [PATCH] fixing flagged style problems

---
 test/test_statistics.py | 128 ++++++++++++++++++++--------------------
 1 file changed, 65 insertions(+), 63 deletions(-)

diff --git a/test/test_statistics.py b/test/test_statistics.py
index ed592842d..13f0474e8 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -33,6 +33,7 @@ import numpy as np
 
 from pymbolic.primitives import Variable
 
+
 def test_op_counter_basic():
 
     knl = lp.make_kernel(
@@ -235,25 +236,25 @@ def test_mem_access_counter_basic():
     params = {'n': n, 'm': m, 'l': l}
     f32l = mem_map[lp.MemAccess('global', np.float32,
                          stride=0, direction='load', variable='a')
-              ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     f32l += mem_map[lp.MemAccess('global', np.float32,
                           stride=0, direction='load', variable='b')
-               ].eval_with_dict(params)
+                    ].eval_with_dict(params)
     f64l = mem_map[lp.MemAccess('global', np.float64,
                          stride=0, direction='load', variable='g')
-              ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     f64l += mem_map[lp.MemAccess('global', np.float64,
                           stride=0, direction='load', variable='h')
-               ].eval_with_dict(params)
+                    ].eval_with_dict(params)
     assert f32l == 3*n*m*l
     assert f64l == 2*n*m
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                          stride=0, direction='store', variable='c')
-              ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     f64s = mem_map[lp.MemAccess('global', np.dtype(np.float64),
                          stride=0, direction='store', variable='e')
-              ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     assert f32s == n*m*l
     assert f64s == n*m
 
@@ -275,21 +276,21 @@ def test_mem_access_counter_reduction():
     params = {'n': n, 'm': m, 'l': l}
     f32l = mem_map[lp.MemAccess('global', np.float32,
                          stride=0, direction='load', variable='a')
-              ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     f32l += mem_map[lp.MemAccess('global', np.float32,
                           stride=0, direction='load', variable='b')
-               ].eval_with_dict(params)
+                    ].eval_with_dict(params)
     assert f32l == 2*n*m*l
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                          stride=0, direction='store', variable='c')
-              ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     assert f32s == n*l
 
     ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load']
-                             ).to_bytes().eval_and_sum(params)
+                                 ).to_bytes().eval_and_sum(params)
     st_bytes = mem_map.filter_by(mtype=['global'], direction=['store']
-                             ).to_bytes().eval_and_sum(params)
+                                 ).to_bytes().eval_and_sum(params)
     assert ld_bytes == 4*f32l
     assert st_bytes == 4*f32s
 
@@ -316,13 +317,13 @@ def test_mem_access_counter_logic():
 
     f32_g_l = reduced_map[lp.MemAccess('global', to_loopy_type(np.float32),
                                        direction='load')
-                         ].eval_with_dict(params)
+                          ].eval_with_dict(params)
     f64_g_l = reduced_map[lp.MemAccess('global', to_loopy_type(np.float64),
                                        direction='load')
-                         ].eval_with_dict(params)
+                          ].eval_with_dict(params)
     f64_g_s = reduced_map[lp.MemAccess('global', to_loopy_type(np.float64),
                                        direction='store')
-                         ].eval_with_dict(params)
+                          ].eval_with_dict(params)
     assert f32_g_l == 2*n*m
     assert f64_g_l == n*m
     assert f64_g_s == n*m
@@ -349,33 +350,34 @@ def test_mem_access_counter_specialops():
     params = {'n': n, 'm': m, 'l': l}
     f32 = mem_map[lp.MemAccess('global', np.float32,
                          stride=0, direction='load', variable='a')
-              ].eval_with_dict(params)
+                  ].eval_with_dict(params)
     f32 += mem_map[lp.MemAccess('global', np.float32,
                           stride=0, direction='load', variable='b')
-               ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     f64 = mem_map[lp.MemAccess('global', np.dtype(np.float64),
                          stride=0, direction='load', variable='g')
-              ].eval_with_dict(params)
+                  ].eval_with_dict(params)
     f64 += mem_map[lp.MemAccess('global', np.dtype(np.float64),
                           stride=0, direction='load', variable='h')
-               ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     assert f32 == 2*n*m*l
     assert f64 == 2*n*m
 
     f32 = mem_map[lp.MemAccess('global', np.float32,
                          stride=0, direction='store', variable='c')
-              ].eval_with_dict(params)
+                  ].eval_with_dict(params)
     f64 = mem_map[lp.MemAccess('global', np.float64,
                          stride=0, direction='store', variable='e')
-              ].eval_with_dict(params)
+                  ].eval_with_dict(params)
     assert f32 == n*m*l
     assert f64 == n*m
 
-    filtered_map = mem_map.filter_by(direction=['load'], variable=['a','g'])
+    filtered_map = mem_map.filter_by(direction=['load'], variable=['a', 'g'])
     #tot = lp.eval_and_sum_polys(filtered_map, params)
     tot = filtered_map.eval_and_sum(params)
     assert tot == n*m*l + n*m
 
+
 def test_mem_access_counter_bitwise():
 
     knl = lp.make_kernel(
@@ -400,24 +402,24 @@ def test_mem_access_counter_bitwise():
     params = {'n': n, 'm': m, 'l': l}
     i32 = mem_map[lp.MemAccess('global', np.int32,
                          stride=0, direction='load', variable='a')
-              ].eval_with_dict(params)
+                  ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.int32,
                           stride=0, direction='load', variable='b')
-               ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.int32,
                           stride=0, direction='load', variable='g')
-               ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.dtype(np.int32),
                           stride=0, direction='load', variable='h')
-               ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     assert i32 == 4*n*m+2*n*m*l
 
     i32 = mem_map[lp.MemAccess('global', np.int32,
                          stride=0, direction='store', variable='c')
-              ].eval_with_dict(params)
+                  ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.int32,
                           stride=0, direction='store', variable='e')
-               ].eval_with_dict(params)
+                   ].eval_with_dict(params)
     assert i32 == n*m+n*m*l
 
 
@@ -446,32 +448,32 @@ def test_mem_access_counter_mixed():
     params = {'n': n, 'm': m, 'l': l}
     f64uniform = mem_map[lp.MemAccess('global', np.float64,
                                 stride=0, direction='load', variable='g')
-                     ].eval_with_dict(params)
+                         ].eval_with_dict(params)
     f64uniform += mem_map[lp.MemAccess('global', np.float64,
                                  stride=0, direction='load', variable='h')
-                      ].eval_with_dict(params)
+                          ].eval_with_dict(params)
     f32uniform = mem_map[lp.MemAccess('global', np.float32,
                                 stride=0, direction='load', variable='x')
-                     ].eval_with_dict(params)
+                         ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                                   stride=Variable('m'), direction='load',
                                   variable='a')
-                       ].eval_with_dict(params)
+                           ].eval_with_dict(params)
     f32nonconsec += mem_map[lp.MemAccess('global', np.dtype(np.float32),
                                    stride=Variable('m'), direction='load',
                                    variable='b')
-                        ].eval_with_dict(params)
+                            ].eval_with_dict(params)
     assert f64uniform == 2*n*m
     assert f32uniform == n*m*l/threads
     assert f32nonconsec == 3*n*m*l
 
     f64uniform = mem_map[lp.MemAccess('global', np.float64,
                                 stride=0, direction='store', variable='e')
-                     ].eval_with_dict(params)
+                         ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.float32,
                                   stride=Variable('m'), direction='store',
                                   variable='c')
-                       ].eval_with_dict(params)
+                           ].eval_with_dict(params)
     assert f64uniform == n*m
     assert f32nonconsec == n*m*l
 
@@ -500,30 +502,30 @@ def test_mem_access_counter_nonconsec():
     f64nonconsec = mem_map[lp.MemAccess('global', np.float64,
                                   stride=Variable('m'), direction='load',
                                   variable='g')
-                       ].eval_with_dict(params)
+                           ].eval_with_dict(params)
     f64nonconsec += mem_map[lp.MemAccess('global', np.float64,
                                    stride=Variable('m'), direction='load',
                                    variable='h')
-                        ].eval_with_dict(params)
+                            ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                                   stride=Variable('m')*Variable('l'),
                                   direction='load', variable='a')
-                       ].eval_with_dict(params)
+                           ].eval_with_dict(params)
     f32nonconsec += mem_map[lp.MemAccess('global', np.dtype(np.float32),
                                    stride=Variable('m')*Variable('l'),
                                    direction='load', variable='b')
-                        ].eval_with_dict(params)
+                            ].eval_with_dict(params)
     assert f64nonconsec == 2*n*m
     assert f32nonconsec == 3*n*m*l
 
     f64nonconsec = mem_map[lp.MemAccess('global', np.float64,
                                   stride=Variable('m'), direction='store',
                                   variable='e')
-                       ].eval_with_dict(params)
+                           ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.float32,
                                   stride=Variable('m')*Variable('l'),
                                   direction='store', variable='c')
-                       ].eval_with_dict(params)
+                           ].eval_with_dict(params)
     assert f64nonconsec == n*m
     assert f32nonconsec == n*m*l
 
@@ -549,30 +551,27 @@ def test_mem_access_counter_consec():
     l = 128
     params = {'n': n, 'm': m, 'l': l}
 
-    #for k in mem_map:
-    #    print(k.mtype, k.dtype, type(k.dtype), k.stride, k.direction, k.variable, " :\n", mem_map[k])
-
     f64consec = mem_map[lp.MemAccess('global', np.float64,
                         stride=1, direction='load', variable='g')
-                     ].eval_with_dict(params)
+                        ].eval_with_dict(params)
     f64consec += mem_map[lp.MemAccess('global', np.float64,
                         stride=1, direction='load', variable='h')
-                     ].eval_with_dict(params)
+                         ].eval_with_dict(params)
     f32consec = mem_map[lp.MemAccess('global', np.float32,
                         stride=1, direction='load', variable='a')
-                     ].eval_with_dict(params)
+                        ].eval_with_dict(params)
     f32consec += mem_map[lp.MemAccess('global', np.dtype(np.float32),
                         stride=1, direction='load', variable='b')
-                     ].eval_with_dict(params)
+                         ].eval_with_dict(params)
     assert f64consec == 2*n*m
     assert f32consec == 3*n*m*l
 
     f64consec = mem_map[lp.MemAccess('global', np.float64,
                         stride=1, direction='store', variable='e')
-                     ].eval_with_dict(params)
+                        ].eval_with_dict(params)
     f32consec = mem_map[lp.MemAccess('global', np.float32,
                         stride=1, direction='store', variable='c')
-                     ].eval_with_dict(params)
+                        ].eval_with_dict(params)
     assert f64consec == n*m
     assert f32consec == n*m*l
 
@@ -671,26 +670,27 @@ def test_all_counters_parallel_matmul():
     op_map = lp.get_mem_access_map(knl)
 
     f32coal = op_map[lp.MemAccess('global', np.float32,
-                        stride=1, direction='load', variable='b')
-                            ].eval_with_dict(params)
+                     stride=1, direction='load', variable='b')
+                     ].eval_with_dict(params)
     f32coal += op_map[lp.MemAccess('global', np.float32,
-                        stride=1, direction='load', variable='a')
-                            ].eval_with_dict(params)
+                      stride=1, direction='load', variable='a')
+                      ].eval_with_dict(params)
 
     assert f32coal == n*m+m*l
 
     f32coal = op_map[lp.MemAccess('global', np.float32,
-                        stride=1, direction='store', variable='c')
-                            ].eval_with_dict(params)
+                     stride=1, direction='store', variable='c')
+                     ].eval_with_dict(params)
 
     assert f32coal == n*l
 
     local_mem_map = lp.get_mem_access_map(knl).filter_by(mtype=['local'])
     local_mem_l = local_mem_map[lp.MemAccess('local', np.dtype(np.float32),
-                                            direction='load')
-                                 ].eval_with_dict(params)
+                                             direction='load')
+                                ].eval_with_dict(params)
     assert local_mem_l == n*m*l*2
 
+
 def test_gather_access_footprint():
     knl = lp.make_kernel(
             "{[i,k,j]: 0<=i,j,k<n}",
@@ -744,25 +744,27 @@ def test_summations_and_filters():
 
     mem_map = lp.get_mem_access_map(knl)
 
-    loads_a = mem_map.filter_by(direction=['load'], variable=['a']).eval_and_sum(params)
+    loads_a = mem_map.filter_by(direction=['load'], variable=['a']
+                                ).eval_and_sum(params)
     assert loads_a == 2*n*m*l
 
-    global_stores = mem_map.filter_by(mtype=['global'], direction=['store']).eval_and_sum(params)
+    global_stores = mem_map.filter_by(mtype=['global'], direction=['store']
+                                      ).eval_and_sum(params)
     assert global_stores == n*m*l + n*m
 
     ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load']
-                             ).to_bytes().eval_and_sum(params)
+                                 ).to_bytes().eval_and_sum(params)
     st_bytes = mem_map.filter_by(mtype=['global'], direction=['store']
-                             ).to_bytes().eval_and_sum(params)
+                                 ).to_bytes().eval_and_sum(params)
     assert ld_bytes == 4*n*m*l*3 + 8*n*m*2
     assert st_bytes == 4*n*m*l + 8*n*m
 
     # ignore stride and variable names in this map
     reduced_map = mem_map.group_by('mtype', 'dtype', 'direction')
     f32lall = reduced_map[lp.MemAccess('global', np.float32, direction='load')
-                         ].eval_with_dict(params)
+                          ].eval_with_dict(params)
     f64lall = reduced_map[lp.MemAccess('global', np.float64, direction='load')
-                         ].eval_with_dict(params)
+                          ].eval_with_dict(params)
     assert f32lall == 3*n*m*l
     assert f64lall == 2*n*m
 
-- 
GitLab