diff --git a/loopy/statistics.py b/loopy/statistics.py
index 2e061ec6d76505ad6373e3fad4462c0459545253..82f36a659d06404f19251701e5b21c55098ec530 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -32,7 +32,8 @@ from pymbolic.mapper import CombineMapper
 from functools import reduce
 
 
-class TypeToCountMap:
+class ToCountMap:
+    """Maps any type of key to an arithmetic type."""
 
     def __init__(self, init_dict=None):
         if init_dict is None:
@@ -46,24 +47,24 @@ class TypeToCountMap:
         for k, v in six.iteritems(other.dict):
             result[k] = self.dict.get(k, 0) + v
 
-        return TypeToCountMap(result)
+        return ToCountMap(result)
 
     def __radd__(self, other):
         if other != 0:
-            raise ValueError("TypeToCountMap: Attempted to add TypeToCountMap "
-                                "to {} {}. TypeToCountMap may only be added to "
-                                "0 and other TypeToCountMap objects."
+            raise ValueError("ToCountMap: Attempted to add ToCountMap "
+                                "to {} {}. ToCountMap may only be added to "
+                                "0 and other ToCountMap objects."
                                 .format(type(other), other))
             return
         return self
 
     def __mul__(self, other):
         if isinstance(other, isl.PwQPolynomial):
-            return TypeToCountMap({index: self.dict[index]*other
+            return ToCountMap({index: self.dict[index]*other
                                      for index in self.dict.keys()})
         else:
-            raise ValueError("TypeToCountMap: Attempted to multiply "
-                                "TypeToCountMap by {} {}."
+            raise ValueError("ToCountMap: Attempted to multiply "
+                                "ToCountMap by {} {}."
                                 .format(type(other), other))
 
     __rmul__ = __mul__
@@ -77,6 +78,9 @@ class TypeToCountMap:
     def __str__(self):
         return str(self.dict)
 
+    def __repr__(self):
+        return repr(self.dict)
+
 
 class ExpressionOpCounter(CombineMapper):
 
@@ -89,7 +93,7 @@ class ExpressionOpCounter(CombineMapper):
         return sum(values)
 
     def map_constant(self, expr):
-        return TypeToCountMap()
+        return ToCountMap()
 
     map_tagged_variable = map_constant
     map_variable = map_constant
@@ -111,16 +115,16 @@ class ExpressionOpCounter(CombineMapper):
 
     def map_sum(self, expr):
         if expr.children:
-            return TypeToCountMap(
+            return ToCountMap(
                         {self.type_inf(expr): len(expr.children)-1}
                         ) + sum(self.rec(child) for child in expr.children)
         else:
-            return TypeToCountMap()
+            return ToCountMap()
 
     map_product = map_sum
 
     def map_quotient(self, expr, *args):
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.numerator) \
                                 + self.rec(expr.denominator)
 
@@ -128,24 +132,24 @@ class ExpressionOpCounter(CombineMapper):
     map_remainder = map_quotient  # implemented in CombineMapper
 
     def map_power(self, expr):
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.base) \
                                 + self.rec(expr.exponent)
 
     def map_left_shift(self, expr):  # implemented in CombineMapper
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.shiftee) \
                                 + self.rec(expr.shift)
 
     map_right_shift = map_left_shift
 
     def map_bitwise_not(self, expr):  # implemented in CombineMapper
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.child)
 
     def map_bitwise_or(self, expr):
         # implemented in CombineMapper, maps to map_sum;
-        return TypeToCountMap(
+        return ToCountMap(
                         {self.type_inf(expr): len(expr.children)-1}
                         ) + sum(self.rec(child) for child in expr.children)
 
@@ -210,7 +214,7 @@ class ExpressionSubscriptCounter(CombineMapper):
         return sum(values)
 
     def map_constant(self, expr):
-        return TypeToCountMap()
+        return ToCountMap()
 
     map_tagged_variable = map_constant
     map_variable = map_constant
@@ -249,13 +253,13 @@ class ExpressionSubscriptCounter(CombineMapper):
 
         if not local_id_found:
             # count as uniform access
-            return TypeToCountMap(
+            return ToCountMap(
                     {(self.type_inf(expr), 'uniform'): 1}
                     ) + self.rec(expr.index)
 
         if local_id0 is None:
             # only non-zero local id(s) found, assume non-consecutive access
-            return TypeToCountMap(
+            return ToCountMap(
                     {(self.type_inf(expr), 'nonconsecutive'): 1}
                     ) + self.rec(expr.index)
 
@@ -274,7 +278,7 @@ class ExpressionSubscriptCounter(CombineMapper):
 
             if coeff_id0 != 1:
                 # non-consecutive access
-                return TypeToCountMap(
+                return ToCountMap(
                         {(self.type_inf(expr), 'nonconsecutive'): 1}
                         ) + self.rec(expr.index)
 
@@ -287,14 +291,14 @@ class ExpressionSubscriptCounter(CombineMapper):
 
             if stride != 1:
                 # non-consecutive
-                return TypeToCountMap(
+                return ToCountMap(
                         {(self.type_inf(expr), 'nonconsecutive'): 1}
                         ) + self.rec(expr.index)
 
             # else, stride == 1, continue since another idx could contain id0
 
         # loop finished without returning, stride==1 for every instance of local_id0
-        return TypeToCountMap(
+        return ToCountMap(
                 {(self.type_inf(expr), 'consecutive'): 1}
                 ) + self.rec(expr.index)
 
@@ -302,7 +306,7 @@ class ExpressionSubscriptCounter(CombineMapper):
         if expr.children:
             return sum(self.rec(child) for child in expr.children)
         else:
-            return TypeToCountMap()
+            return ToCountMap()
 
     map_product = map_sum
 
@@ -413,7 +417,7 @@ def get_op_poly(knl):
         insn_inames = knl.insn_inames(insn)
         inames_domain = knl.get_inames_domain(insn_inames)
         domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
-        ops = op_counter(insn.expression)
+        ops = op_counter(insn.assignee) + op_counter(insn.expression)
         op_poly = op_poly + ops*count(knl, domain)
     return op_poly
 
@@ -429,8 +433,18 @@ def get_DRAM_access_poly(knl):  # for now just counting subscripts
         insn_inames = knl.insn_inames(insn)
         inames_domain = knl.get_inames_domain(insn_inames)
         domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
-        subs = subscript_counter(insn.expression)
-        subs_poly = subs_poly + subs*count(knl, domain)
+
+        subs_expr = subscript_counter(insn.expression)
+        subs_expr = ToCountMap(dict(
+            (key + ("load",), val)
+            for key, val in six.iteritems(subs_expr.dict)))
+
+        subs_assignee = subscript_counter(insn.assignee)
+        subs_assignee = ToCountMap(dict(
+            (key + ("store",), val)
+            for key, val in six.iteritems(subs_assignee.dict)))
+
+        subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain)
     return subs_poly
 
 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 17e0cc54359d2c5ae5a19042063b0c5a0603ca22..0937e208f502bc81ed84b202e9b7e8367f75b300 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2079,6 +2079,8 @@ def test_vectorize(ctx_factory):
     knl = lp.add_and_infer_dtypes(knl, dict(b=np.float32))
     knl = lp.split_arg_axis(knl, [("a", 0), ("b", 0)], 4,
             split_kwargs=dict(slabs=(0, 1)))
+    print(knl)
+    1/0
 
     knl = lp.tag_data_axes(knl, "a,b", "c,vec")
     ref_knl = knl
diff --git a/test/test_statistics.py b/test/test_statistics.py
index dc040864f4a0affe4f0356008d1f5ea46450f471..80c9738d68791fc892e4939148939d3db509d637 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -38,7 +38,7 @@ def test_op_counter_basic():
             [
                 """
                 c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
-                e[i, k] = g[i,k]*h[i,k+1]
+                e[i, k+1] = g[i,k]*h[i,k+1]
                 """
             ],
             name="basic", assumptions="n,m,l >= 1")
@@ -54,7 +54,7 @@ def test_op_counter_basic():
     i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 3*n*m*l
     assert f64 == n*m
-    assert i32 == n*m
+    assert i32 == n*m*2
 
 
 def test_op_counter_reduction():
@@ -204,14 +204,23 @@ def test_DRAM_access_counter_basic():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f64 = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 3*n*m*l
     assert f64 == 2*n*m
 
+    f32 = poly.dict[
+                    (np.dtype(np.float32), 'uniform', 'store')
+                   ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                   ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == n*m*l
+    assert f64 == n*m
+
 
 def test_DRAM_access_counter_reduction():
 
@@ -228,10 +237,15 @@ def test_DRAM_access_counter_reduction():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 2*n*m*l
 
+    f32 = poly.dict[
+                    (np.dtype(np.float32), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == n*l
+
 
 def test_DRAM_access_counter_logic():
 
@@ -250,14 +264,19 @@ def test_DRAM_access_counter_logic():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f64 = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 2*n*m
     assert f64 == n*m
 
+    f64 = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f64 == n*m
+
 
 def test_DRAM_access_counter_specialops():
 
@@ -278,14 +297,23 @@ def test_DRAM_access_counter_specialops():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f64 = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 2*n*m*l
     assert f64 == 2*n*m
 
+    f32 = poly.dict[
+                    (np.dtype(np.float32), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == n*m*l
+    assert f64 == n*m
+
 
 def test_DRAM_access_counter_bitwise():
 
@@ -309,10 +337,15 @@ def test_DRAM_access_counter_bitwise():
     m = 256
     l = 128
     i32 = poly.dict[
-                    (np.dtype(np.int32), 'uniform')
+                    (np.dtype(np.int32), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert i32 == 4*n*m+2*n*m*l
 
+    i32 = poly.dict[
+                    (np.dtype(np.int32), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert i32 == n*m+n*m*l
+
 
 def test_DRAM_access_counter_mixed():
 
@@ -335,14 +368,23 @@ def test_DRAM_access_counter_mixed():
     m = 256
     l = 128
     f64uniform = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32nonconsec = poly.dict[
-                    (np.dtype(np.float32), 'nonconsecutive')
+                    (np.dtype(np.float32), 'nonconsecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f64uniform == 2*n*m
     assert f32nonconsec == 3*n*m*l
 
+    f64uniform = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f32nonconsec = poly.dict[
+                    (np.dtype(np.float32), 'nonconsecutive', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f64uniform == n*m
+    assert f32nonconsec == n*m*l
+
 
 def test_DRAM_access_counter_nonconsec():
 
@@ -365,10 +407,10 @@ def test_DRAM_access_counter_nonconsec():
     m = 256
     l = 128
     f64nonconsec = poly.dict[
-                    (np.dtype(np.float64), 'nonconsecutive')
+                    (np.dtype(np.float64), 'nonconsecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32nonconsec = poly.dict[
-                    (np.dtype(np.float32), 'nonconsecutive')
+                    (np.dtype(np.float32), 'nonconsecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f64nonconsec == 2*n*m
     assert f32nonconsec == 3*n*m*l
@@ -395,10 +437,10 @@ def test_DRAM_access_counter_consec():
     l = 128
     print(poly)
     f64consec = poly.dict[
-                    (np.dtype(np.float64), 'consecutive')
+                    (np.dtype(np.float64), 'consecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32consec = poly.dict[
-                    (np.dtype(np.float32), 'consecutive')
+                    (np.dtype(np.float32), 'consecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f64consec == 2*n*m
     assert f32consec == 3*n*m*l
@@ -468,6 +510,7 @@ def test_all_counters_parallel_matmul():
     l = 128
 
     barrier_count = get_barrier_poly(knl).eval_with_dict({'n': n, 'm': n, 'l': n})
+    assert barrier_count == 0
 
     op_map = get_op_poly(knl)
     f32ops = op_map.dict[
@@ -477,20 +520,25 @@ def test_all_counters_parallel_matmul():
                         np.dtype(np.int32)
                         ].eval_with_dict({'n': n, 'm': m, 'l': l})
 
+    assert f32ops == n*m*l*2
+    assert i32ops == n*m*l*4 + l*n*4
+
     subscript_map = get_DRAM_access_poly(knl)
     f32uncoal = subscript_map.dict[
-                        (np.dtype(np.float32), 'nonconsecutive')
+                        (np.dtype(np.float32), 'nonconsecutive', 'load')
                         ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32coal = subscript_map.dict[
-                        (np.dtype(np.float32), 'consecutive')
+                        (np.dtype(np.float32), 'consecutive', 'load')
                         ].eval_with_dict({'n': n, 'm': m, 'l': l})
 
-    assert barrier_count == 0
-    assert f32ops == n*m*l*2
-    assert i32ops == n*m*l*4
     assert f32uncoal == n*m*l
     assert f32coal == n*m*l
 
+    f32coal = subscript_map.dict[
+                        (np.dtype(np.float32), 'consecutive', 'store')
+                        ].eval_with_dict({'n': n, 'm': m, 'l': l})
+
+    assert f32coal == n*l
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
@@ -498,4 +546,3 @@ if __name__ == "__main__":
     else:
         from py.test.cmdline import main
         main([__file__])
-