diff --git a/loopy/statistics.py b/loopy/statistics.py
index b4a0b20b195a48f7767d57f76eba6e586c7b885a..85ac4c77bdad2573b6a7a033aa0e17221e010d59 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -32,7 +32,8 @@ from pymbolic.mapper import CombineMapper
 from functools import reduce
 
 
-class TypeToCountMap:
+class ToCountMap:
+    """Maps keys of any hashable type to values of an arithmetic type."""
 
     def __init__(self, init_dict=None):
         if init_dict is None:
@@ -46,24 +47,24 @@ class TypeToCountMap:
         for k, v in six.iteritems(other.dict):
             result[k] = self.dict.get(k, 0) + v
 
-        return TypeToCountMap(result)
+        return ToCountMap(result)
 
     def __radd__(self, other):
         if other != 0:
-            raise ValueError("TypeToCountMap: Attempted to add TypeToCountMap "
-                                "to {} {}. TypeToCountMap may only be added to "
-                                "0 and other TypeToCountMap objects."
+            raise ValueError("ToCountMap: Attempted to add ToCountMap "
+                                "to {} {}. ToCountMap may only be added to "
+                                "0 and other ToCountMap objects."
                                 .format(type(other), other))
             return
         return self
 
     def __mul__(self, other):
         if isinstance(other, isl.PwQPolynomial):
-            return TypeToCountMap({index: self.dict[index]*other
+            return ToCountMap({index: self.dict[index]*other
                                      for index in self.dict.keys()})
         else:
-            raise ValueError("TypeToCountMap: Attempted to multiply "
-                                "TypeToCountMap by {} {}."
+            raise ValueError("ToCountMap: Attempted to multiply "
+                                "ToCountMap by {} {}."
                                 .format(type(other), other))
 
     __rmul__ = __mul__
@@ -77,6 +78,9 @@ class TypeToCountMap:
     def __str__(self):
         return str(self.dict)
 
+    def __repr__(self):
+        return repr(self.dict)
+
 
 class ExpressionOpCounter(CombineMapper):
 
@@ -89,7 +93,7 @@ class ExpressionOpCounter(CombineMapper):
         return sum(values)
 
     def map_constant(self, expr):
-        return TypeToCountMap()
+        return ToCountMap()
 
     map_tagged_variable = map_constant
     map_variable = map_constant
@@ -111,16 +115,16 @@ class ExpressionOpCounter(CombineMapper):
 
     def map_sum(self, expr):
         if expr.children:
-            return TypeToCountMap(
+            return ToCountMap(
                         {self.type_inf(expr): len(expr.children)-1}
                         ) + sum(self.rec(child) for child in expr.children)
         else:
-            return TypeToCountMap()
+            return ToCountMap()
 
     map_product = map_sum
 
     def map_quotient(self, expr, *args):
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.numerator) \
                                 + self.rec(expr.denominator)
 
@@ -128,24 +132,24 @@ class ExpressionOpCounter(CombineMapper):
     map_remainder = map_quotient  # implemented in CombineMapper
 
     def map_power(self, expr):
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.base) \
                                 + self.rec(expr.exponent)
 
     def map_left_shift(self, expr):  # implemented in CombineMapper
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.shiftee) \
                                 + self.rec(expr.shift)
 
     map_right_shift = map_left_shift
 
     def map_bitwise_not(self, expr):  # implemented in CombineMapper
-        return TypeToCountMap({self.type_inf(expr): 1}) \
+        return ToCountMap({self.type_inf(expr): 1}) \
                                 + self.rec(expr.child)
 
     def map_bitwise_or(self, expr):
         # implemented in CombineMapper, maps to map_sum;
-        return TypeToCountMap(
+        return ToCountMap(
                         {self.type_inf(expr): len(expr.children)-1}
                         ) + sum(self.rec(child) for child in expr.children)
 
@@ -210,7 +214,7 @@ class ExpressionSubscriptCounter(CombineMapper):
         return sum(values)
 
     def map_constant(self, expr):
-        return TypeToCountMap()
+        return ToCountMap()
 
     map_tagged_variable = map_constant
     map_variable = map_constant
@@ -249,13 +253,13 @@ class ExpressionSubscriptCounter(CombineMapper):
 
         if not local_id_found:
             # count as uniform access
-            return TypeToCountMap(
+            return ToCountMap(
                     {(self.type_inf(expr), 'uniform'): 1}
                     ) + self.rec(expr.index)
 
         if local_id0 is None:
             # only non-zero local id(s) found, assume non-consecutive access
-            return TypeToCountMap(
+            return ToCountMap(
                     {(self.type_inf(expr), 'nonconsecutive'): 1}
                     ) + self.rec(expr.index)
 
@@ -274,7 +278,7 @@ class ExpressionSubscriptCounter(CombineMapper):
 
             if coeff_id0 != 1:
                 # non-consecutive access
-                return TypeToCountMap(
+                return ToCountMap(
                         {(self.type_inf(expr), 'nonconsecutive'): 1}
                         ) + self.rec(expr.index)
 
@@ -287,14 +291,14 @@ class ExpressionSubscriptCounter(CombineMapper):
 
             if stride != 1:
                 # non-consecutive
-                return TypeToCountMap(
+                return ToCountMap(
                         {(self.type_inf(expr), 'nonconsecutive'): 1}
                         ) + self.rec(expr.index)
 
             # else, stride == 1, continue since another idx could contain id0
 
         # loop finished without returning, stride==1 for every instance of local_id0
-        return TypeToCountMap(
+        return ToCountMap(
                 {(self.type_inf(expr), 'consecutive'): 1}
                 ) + self.rec(expr.index)
 
@@ -302,7 +306,7 @@ class ExpressionSubscriptCounter(CombineMapper):
         if expr.children:
             return sum(self.rec(child) for child in expr.children)
         else:
-            return TypeToCountMap()
+            return ToCountMap()
 
     map_product = map_sum
 
@@ -413,7 +417,7 @@ def get_op_poly(knl):
         insn_inames = knl.insn_inames(insn)
         inames_domain = knl.get_inames_domain(insn_inames)
         domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
-        ops = op_counter(insn.expression) + op_counter(insn.assignee)
+        ops = op_counter(insn.assignee) + op_counter(insn.expression)
         op_poly = op_poly + ops*count(knl, domain)
     return op_poly
 
@@ -429,8 +433,17 @@ def get_DRAM_access_poly(knl):  # for now just counting subscripts
         insn_inames = knl.insn_inames(insn)
         inames_domain = knl.get_inames_domain(insn_inames)
         domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
-        subs = subscript_counter(insn.expression) + subscript_counter(insn.assignee)
-        subs_poly = subs_poly + subs*count(knl, domain)
+        subs_expr = subscript_counter(insn.expression)
+        subs_expr = ToCountMap(dict(
+            (key + ("load",), val)
+            for key, val in six.iteritems(subs_expr.dict)))
+
+        subs_assignee = subscript_counter(insn.assignee)
+        subs_assignee = ToCountMap(dict(
+            (key + ("store",), val)
+            for key, val in six.iteritems(subs_assignee.dict)))
+
+        subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain)
     return subs_poly
 
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index a77c75cf8e2e7f1407c46cd059d32360269f2829..fa4ad2c05ec9db313b94b9ce762b0e90fec53894 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -204,13 +204,22 @@ def test_DRAM_access_counter_basic():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f64 = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f32 == 4*n*m*l
-    assert f64 == 3*n*m
+    assert f32 == 3*n*m*l
+    assert f64 == 2*n*m
+
+    f32 = poly.dict[
+                    (np.dtype(np.float32), 'uniform', 'store')
+                   ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                   ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == n*m*l
+    assert f64 == n*m
 
 
 def test_DRAM_access_counter_reduction():
@@ -228,9 +237,14 @@ def test_DRAM_access_counter_reduction():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f32 == 2*n*m*l+n*l
+    assert f32 == 2*n*m*l
+
+    f32 = poly.dict[
+                    (np.dtype(np.float32), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == n*l
 
 
 def test_DRAM_access_counter_logic():
@@ -250,13 +264,18 @@ def test_DRAM_access_counter_logic():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f64 = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     assert f32 == 2*n*m
-    assert f64 == 2*n*m
+    assert f64 == n*m
+
+    f64 = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f64 == n*m
 
 
 def test_DRAM_access_counter_specialops():
@@ -278,13 +297,22 @@ def test_DRAM_access_counter_specialops():
     m = 256
     l = 128
     f32 = poly.dict[
-                    (np.dtype(np.float32), 'uniform')
+                    (np.dtype(np.float32), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f64 = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f32 == 3*n*m*l
-    assert f64 == 3*n*m
+    assert f32 == 2*n*m*l
+    assert f64 == 2*n*m
+
+    f32 = poly.dict[
+                    (np.dtype(np.float32), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == n*m*l
+    assert f64 == n*m
 
 
 def test_DRAM_access_counter_bitwise():
@@ -309,9 +337,14 @@ def test_DRAM_access_counter_bitwise():
     m = 256
     l = 128
     i32 = poly.dict[
-                    (np.dtype(np.int32), 'uniform')
+                    (np.dtype(np.int32), 'uniform', 'load')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert i32 == 4*n*m+2*n*m*l
+
+    i32 = poly.dict[
+                    (np.dtype(np.int32), 'uniform', 'store')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert i32 == 5*n*m+3*n*m*l
+    assert i32 == n*m+n*m*l
 
 
 def test_DRAM_access_counter_mixed():
@@ -335,13 +368,22 @@ def test_DRAM_access_counter_mixed():
     m = 256
     l = 128
     f64uniform = poly.dict[
-                    (np.dtype(np.float64), 'uniform')
+                    (np.dtype(np.float64), 'uniform', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32nonconsec = poly.dict[
-                    (np.dtype(np.float32), 'nonconsecutive')
+                    (np.dtype(np.float32), 'nonconsecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f64uniform == 3*n*m
-    assert f32nonconsec == 4*n*m*l
+    assert f64uniform == 2*n*m
+    assert f32nonconsec == 3*n*m*l
+
+    f64uniform = poly.dict[
+                    (np.dtype(np.float64), 'uniform', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f32nonconsec = poly.dict[
+                    (np.dtype(np.float32), 'nonconsecutive', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f64uniform == n*m
+    assert f32nonconsec == n*m*l
 
 
 def test_DRAM_access_counter_nonconsec():
@@ -365,13 +407,22 @@ def test_DRAM_access_counter_nonconsec():
     m = 256
     l = 128
     f64nonconsec = poly.dict[
-                    (np.dtype(np.float64), 'nonconsecutive')
+                    (np.dtype(np.float64), 'nonconsecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32nonconsec = poly.dict[
-                    (np.dtype(np.float32), 'nonconsecutive')
+                    (np.dtype(np.float32), 'nonconsecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f64nonconsec == 3*n*m
-    assert f32nonconsec == 4*n*m*l
+    assert f64nonconsec == 2*n*m
+    assert f32nonconsec == 3*n*m*l
+
+    f64nonconsec = poly.dict[
+                    (np.dtype(np.float64), 'nonconsecutive', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f32nonconsec = poly.dict[
+                    (np.dtype(np.float32), 'nonconsecutive', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f64nonconsec == n*m
+    assert f32nonconsec == n*m*l
 
 
 def test_DRAM_access_counter_consec():
@@ -393,15 +444,24 @@ def test_DRAM_access_counter_consec():
     n = 512
     m = 256
     l = 128
-    print(poly)
+
     f64consec = poly.dict[
-                    (np.dtype(np.float64), 'consecutive')
+                    (np.dtype(np.float64), 'consecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32consec = poly.dict[
-                    (np.dtype(np.float32), 'consecutive')
+                    (np.dtype(np.float32), 'consecutive', 'load')
                     ].eval_with_dict({'n': n, 'm': m, 'l': l})
-    assert f64consec == 3*n*m
-    assert f32consec == 4*n*m*l
+    assert f64consec == 2*n*m
+    assert f32consec == 3*n*m*l
+
+    f64consec = poly.dict[
+                    (np.dtype(np.float64), 'consecutive', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f32consec = poly.dict[
+                    (np.dtype(np.float32), 'consecutive', 'store')
+                    ].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f64consec == n*m
+    assert f32consec == n*m*l
 
 
 def test_barrier_counter_nobarriers():
@@ -468,6 +528,7 @@ def test_all_counters_parallel_matmul():
     l = 128
 
     barrier_count = get_barrier_poly(knl).eval_with_dict({'n': n, 'm': n, 'l': n})
+    assert barrier_count == 0
 
     op_map = get_op_poly(knl)
     f32ops = op_map.dict[
@@ -477,20 +538,25 @@ def test_all_counters_parallel_matmul():
                         np.dtype(np.int32)
                         ].eval_with_dict({'n': n, 'm': m, 'l': l})
 
+    assert f32ops == n*m*l*2
+    assert i32ops == n*m*l*4 + l*n*4
+
     subscript_map = get_DRAM_access_poly(knl)
     f32uncoal = subscript_map.dict[
-                        (np.dtype(np.float32), 'nonconsecutive')
+                        (np.dtype(np.float32), 'nonconsecutive', 'load')
                         ].eval_with_dict({'n': n, 'm': m, 'l': l})
     f32coal = subscript_map.dict[
-                        (np.dtype(np.float32), 'consecutive')
+                        (np.dtype(np.float32), 'consecutive', 'load')
                         ].eval_with_dict({'n': n, 'm': m, 'l': l})
 
-    assert barrier_count == 0
-    assert f32ops == n*m*l*2
-    assert i32ops == n*m*l*4+l*n*4
     assert f32uncoal == n*m*l
-    assert f32coal == n*m*l+n*l
+    assert f32coal == n*m*l
 
+    f32coal = subscript_map.dict[
+                        (np.dtype(np.float32), 'consecutive', 'store')
+                        ].eval_with_dict({'n': n, 'm': m, 'l': l})
+
+    assert f32coal == n*l
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
@@ -498,4 +564,3 @@ if __name__ == "__main__":
     else:
         from py.test.cmdline import main
         main([__file__])
-