diff --git a/doc/tutorial.rst b/doc/tutorial.rst
index 217e1ef7c323ca13f8a1aaf81e8ea30c08b784a7..4efc13de4bc93b1024ad4ccd40d4d1ef5395643b 100644
--- a/doc/tutorial.rst
+++ b/doc/tutorial.rst
@@ -1553,11 +1553,11 @@ information provided. Now we will count the operations:
 
     >>> op_map = lp.get_op_map(knl)
     >>> print(lp.stringify_stats_mapping(op_map))
-    Op(np:dtype('float32'), add) : ...
+    Op(np:dtype('float32'), add, workitem) : ...
 
 Each line of output will look roughly like::
 
-    Op(np:dtype('float32'), add) : [l, m, n] -> { l * m * n : l > 0 and m > 0 and n > 0 }
+    Op(np:dtype('float32'), add, workitem) : [l, m, n] -> { l * m * n : l > 0 and m > 0 and n > 0 }
 
 :func:`loopy.get_op_map` returns a :class:`loopy.ToCountMap` of **{**
 :class:`loopy.Op` **:** :class:`islpy.PwQPolynomial` **}**. A
@@ -1578,12 +1578,13 @@ One way to evaluate these polynomials is with :func:`islpy.eval_with_dict`:
 .. doctest::
 
     >>> param_dict = {'n': 256, 'm': 256, 'l': 8}
-    >>> f32add = op_map[lp.Op(np.float32, 'add')].eval_with_dict(param_dict)
-    >>> f32div = op_map[lp.Op(np.float32, 'div')].eval_with_dict(param_dict)
-    >>> f32mul = op_map[lp.Op(np.float32, 'mul')].eval_with_dict(param_dict)
-    >>> f64add = op_map[lp.Op(np.float64, 'add')].eval_with_dict(param_dict)
-    >>> f64mul = op_map[lp.Op(np.float64, 'mul')].eval_with_dict(param_dict)
-    >>> i32add = op_map[lp.Op(np.int32, 'add')].eval_with_dict(param_dict)
+    >>> from loopy.statistics import CountGranularity as CG
+    >>> f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(param_dict)
+    >>> f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(param_dict)
+    >>> f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(param_dict)
+    >>> f64add = op_map[lp.Op(np.float64, 'add', CG.WORKITEM)].eval_with_dict(param_dict)
+    >>> f64mul = op_map[lp.Op(np.float64, 'mul', CG.WORKITEM)].eval_with_dict(param_dict)
+    >>> i32add = op_map[lp.Op(np.int32, 'add', CG.WORKITEM)].eval_with_dict(param_dict)
     >>> print("%i\n%i\n%i\n%i\n%i\n%i" %
     ...     (f32add, f32div, f32mul, f64add, f64mul, i32add))
     524288
@@ -1614,7 +1615,7 @@ together into keys containing only the specified fields:
 
     >>> op_map_dtype = op_map.group_by('dtype')
     >>> print(lp.stringify_stats_mapping(op_map_dtype))
-    Op(np:dtype('float32'), None) : ...
+    Op(np:dtype('float32'), None, None) : ...
     <BLANKLINE>
     >>> f32op_count = op_map_dtype[lp.Op(dtype=np.float32)
     ...                           ].eval_with_dict(param_dict)
@@ -1623,8 +1624,8 @@ together into keys containing only the specified fields:
 
 The lines of output above might look like::
 
-    Op(np:dtype('float32'), None) : [m, l, n] -> { 3 * m * l * n : m > 0 and l > 0 and n > 0 }
-    Op(np:dtype('float64'), None) : [m, l, n] -> { 2 * m * n : m > 0 and l > 0 and n > 0 }
+    Op(np:dtype('float32'), None, None) : [m, l, n] -> { 3 * m * l * n : m > 0 and l > 0 and n > 0 }
+    Op(np:dtype('float64'), None, None) : [m, l, n] -> { 2 * m * n : m > 0 and l > 0 and n > 0 }
 
 See the reference page for :class:`loopy.ToCountMap` and :class:`loopy.Op` for
 more information on these functions.
@@ -1638,17 +1639,17 @@ we'll continue using the kernel from the previous example:
 
 .. doctest::
 
-    >>> mem_map = lp.get_mem_access_map(knl)
+    >>> mem_map = lp.get_mem_access_map(knl, subgroup_size=32)
     >>> print(lp.stringify_stats_mapping(mem_map))
-    MemAccess(global, np:dtype('float32'), 0, load, a) : ...
+    MemAccess(global, np:dtype('float32'), 0, load, a, subgroup) : ...
     <BLANKLINE>
 
 Each line of output will look roughly like::
 
 
-    MemAccess(global, np:dtype('float32'), 0, load, a) : [m, l, n] -> { 2 * m * l * n : m > 0 and l > 0 and n > 0 }
-    MemAccess(global, np:dtype('float32'), 0, load, b) : [m, l, n] -> { m * l * n : m > 0 and l > 0 and n > 0 }
-    MemAccess(global, np:dtype('float32'), 0, store, c) : [m, l, n] -> { m * l * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float32'), 0, load, a, subgroup) : [m, l, n] -> { 2 * m * l * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float32'), 0, load, b, subgroup) : [m, l, n] -> { m * l * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float32'), 0, store, c, subgroup) : [m, l, n] -> { m * l * n : m > 0 and l > 0 and n > 0 }
 
 :func:`loopy.get_mem_access_map` returns a :class:`loopy.ToCountMap` of **{**
 :class:`loopy.MemAccess` **:** :class:`islpy.PwQPolynomial` **}**.
@@ -1661,7 +1662,7 @@ Each line of output will look roughly like::
   data type accessed.
 
 - stride: An :class:`int` that specifies stride of the memory access. A stride
-  of 0 indicates a uniform access (i.e. all threads access the same item).
+  of 0 indicates a uniform access (i.e. all work-items access the same item).
 
 - direction: A :class:`str` that specifies the direction of memory access as
   **load** or **store**.
@@ -1673,13 +1674,13 @@ We can evaluate these polynomials using :func:`islpy.eval_with_dict`:
 
 .. doctest::
 
-    >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, 0, 'load', 'g')
+    >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, 0, 'load', 'g', CG.SUBGROUP)
     ...                  ].eval_with_dict(param_dict)
-    >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, 0, 'store', 'e')
+    >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, 0, 'store', 'e', CG.SUBGROUP)
     ...                  ].eval_with_dict(param_dict)
-    >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, 0, 'load', 'a')
+    >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, 0, 'load', 'a', CG.SUBGROUP)
     ...                  ].eval_with_dict(param_dict)
-    >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, 0, 'store', 'c')
+    >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, 0, 'store', 'c', CG.SUBGROUP)
     ...                  ].eval_with_dict(param_dict)
     >>> print("f32 ld a: %i\nf32 st c: %i\nf64 ld g: %i\nf64 st e: %i" %
     ...       (f32ld_a, f32st_c, f64ld_g, f64st_e))
@@ -1697,13 +1698,13 @@ using :func:`loopy.ToCountMap.to_bytes` and :func:`loopy.ToCountMap.group_by`:
 
     >>> bytes_map = mem_map.to_bytes()
     >>> print(lp.stringify_stats_mapping(bytes_map))
-    MemAccess(global, np:dtype('float32'), 0, load, a) : ...
+    MemAccess(global, np:dtype('float32'), 0, load, a, subgroup) : ...
     <BLANKLINE>
     >>> global_ld_st_bytes = bytes_map.filter_by(mtype=['global']
     ...                                         ).group_by('direction')
     >>> print(lp.stringify_stats_mapping(global_ld_st_bytes))
-    MemAccess(None, None, None, load, None) : ...
-    MemAccess(None, None, None, store, None) : ...
+    MemAccess(None, None, None, load, None, None) : ...
+    MemAccess(None, None, None, store, None, None) : ...
     <BLANKLINE>
     >>> loaded = global_ld_st_bytes[lp.MemAccess(direction='load')
     ...                            ].eval_with_dict(param_dict)
@@ -1715,12 +1716,12 @@ using :func:`loopy.ToCountMap.to_bytes` and :func:`loopy.ToCountMap.group_by`:
 
 The lines of output above might look like::
 
-    MemAccess(global, np:[m, l, n] -> { 8 * m * l * n : m > 0 and l > 0 and n > 0 }
-    MemAccess(global, np:dtype('float32'), 0, load, b) : [m, l, n] -> { 4 * m * l * n : m > 0 and l > 0 and n > 0 }
-    MemAccess(global, np:dtype('float32'), 0, store, c) : [m, l, n] -> { 4 * m * l * n : m > 0 and l > 0 and n > 0 }
-    MemAccess(global, np:dtype('float64'), 0, load, g) : [m, l, n] -> { 8 * m * n : m > 0 and l > 0 and n > 0 }
-    MemAccess(global, np:dtype('float64'), 0, load, h) : [m, l, n] -> { 8 * m * n : m > 0 and l > 0 and n > 0 }
-    MemAccess(global, np:dtype('float64'), 0, store, e) : [m, l, n] -> { 8 * m * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float32'), 0, load, a, subgroup) : [m, l, n] -> { 8 * m * l * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float32'), 0, load, b, subgroup) : [m, l, n] -> { 4 * m * l * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float32'), 0, store, c, subgroup) : [m, l, n] -> { 4 * m * l * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float64'), 0, load, g, subgroup) : [m, l, n] -> { 8 * m * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float64'), 0, load, h, subgroup) : [m, l, n] -> { 8 * m * n : m > 0 and l > 0 and n > 0 }
+    MemAccess(global, np:dtype('float64'), 0, store, e, subgroup) : [m, l, n] -> { 8 * m * n : m > 0 and l > 0 and n > 0 }
 
 One can see how these functions might be useful in computing, for example,
 achieved memory bandwidth in byte/sec or performance in FLOP/sec.
@@ -1728,7 +1729,7 @@ achieved memory bandwidth in byte/sec or performance in FLOP/sec.
 ~~~~~~~~~~~
 
 Since we have not tagged any of the inames or parallelized the kernel across
-threads (which would have produced iname tags), :func:`loopy.get_mem_access_map`
+work-items (which would have produced iname tags), :func:`loopy.get_mem_access_map`
 considers the memory accesses *uniform*, so the *stride* of each access is 0.
 Now we'll parallelize the kernel and count the array accesses again. The
 resulting :class:`islpy.PwQPolynomial` will be more complicated this time.
@@ -1737,30 +1738,30 @@ resulting :class:`islpy.PwQPolynomial` will be more complicated this time.
 
     >>> knl_consec = lp.split_iname(knl, "k", 128,
     ...                             outer_tag="l.1", inner_tag="l.0")
-    >>> mem_map = lp.get_mem_access_map(knl_consec)
+    >>> mem_map = lp.get_mem_access_map(knl_consec, subgroup_size=32)
     >>> print(lp.stringify_stats_mapping(mem_map))
-    MemAccess(global, np:dtype('float32'), 1, load, a) : ...
-    MemAccess(global, np:dtype('float32'), 1, load, b) : ...
-    MemAccess(global, np:dtype('float32'), 1, store, c) : ...
-    MemAccess(global, np:dtype('float64'), 1, load, g) : ...
-    MemAccess(global, np:dtype('float64'), 1, load, h) : ...
-    MemAccess(global, np:dtype('float64'), 1, store, e) : ...
+    MemAccess(global, np:dtype('float32'), 1, load, a, workitem) : ...
+    MemAccess(global, np:dtype('float32'), 1, load, b, workitem) : ...
+    MemAccess(global, np:dtype('float32'), 1, store, c, workitem) : ...
+    MemAccess(global, np:dtype('float64'), 1, load, g, workitem) : ...
+    MemAccess(global, np:dtype('float64'), 1, load, h, workitem) : ...
+    MemAccess(global, np:dtype('float64'), 1, store, e, workitem) : ...
     <BLANKLINE>
 
-With this parallelization, consecutive threads will access consecutive array
+With this parallelization, consecutive work-items will access consecutive array
 elements in memory. The polynomials are a bit more complicated now due to the
 parallelization, but when we evaluate them, we see that the total number of
 array accesses has not changed:
 
 .. doctest::
 
-    >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, 1, 'load', 'g')
+    >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, 1, 'load', 'g', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
-    >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, 1, 'store', 'e')
+    >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, 1, 'store', 'e', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
-    >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, 1, 'load', 'a')
+    >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, 1, 'load', 'a', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
-    >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, 1, 'store', 'c')
+    >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, 1, 'store', 'c', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
     >>> print("f32 ld a: %i\nf32 st c: %i\nf64 ld g: %i\nf64 st e: %i" %
     ...       (f32ld_a, f32st_c, f64ld_g, f64st_e))
@@ -1778,29 +1779,29 @@ switch the inner and outer tags in our parallelization of the kernel:
 
     >>> knl_nonconsec = lp.split_iname(knl, "k", 128,
     ...                                outer_tag="l.0", inner_tag="l.1")
-    >>> mem_map = lp.get_mem_access_map(knl_nonconsec)
+    >>> mem_map = lp.get_mem_access_map(knl_nonconsec, subgroup_size=32)
     >>> print(lp.stringify_stats_mapping(mem_map))
-    MemAccess(global, np:dtype('float32'), 128, load, a) : ...
-    MemAccess(global, np:dtype('float32'), 128, load, b) : ...
-    MemAccess(global, np:dtype('float32'), 128, store, c) : ...
-    MemAccess(global, np:dtype('float64'), 128, load, g) : ...
-    MemAccess(global, np:dtype('float64'), 128, load, h) : ...
-    MemAccess(global, np:dtype('float64'), 128, store, e) : ...
+    MemAccess(global, np:dtype('float32'), 128, load, a, workitem) : ...
+    MemAccess(global, np:dtype('float32'), 128, load, b, workitem) : ...
+    MemAccess(global, np:dtype('float32'), 128, store, c, workitem) : ...
+    MemAccess(global, np:dtype('float64'), 128, load, g, workitem) : ...
+    MemAccess(global, np:dtype('float64'), 128, load, h, workitem) : ...
+    MemAccess(global, np:dtype('float64'), 128, store, e, workitem) : ...
     <BLANKLINE>
 
-With this parallelization, consecutive threads will access *nonconsecutive*
+With this parallelization, consecutive work-items will access *nonconsecutive*
 array elements in memory. The total number of array accesses still has not
 changed:
 
 .. doctest::
 
-    >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, 128, 'load', 'g')
+    >>> f64ld_g = mem_map[lp.MemAccess('global', np.float64, 128, 'load', 'g', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
-    >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, 128, 'store', 'e')
+    >>> f64st_e = mem_map[lp.MemAccess('global', np.float64, 128, 'store', 'e', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
-    >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, 128, 'load', 'a')
+    >>> f32ld_a = mem_map[lp.MemAccess('global', np.float32, 128, 'load', 'a', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
-    >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, 128, 'store', 'c')
+    >>> f32st_c = mem_map[lp.MemAccess('global', np.float32, 128, 'store', 'c', CG.WORKITEM)
     ...                  ].eval_with_dict(param_dict)
     >>> print("f32 ld a: %i\nf32 st c: %i\nf64 ld g: %i\nf64 st e: %i" %
     ...       (f32ld_a, f32st_c, f64ld_g, f64st_e))
@@ -1827,7 +1828,7 @@ Counting synchronization events
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 :func:`loopy.get_synchronization_map` counts the number of synchronization
-events per **thread** in a kernel. First, we'll call this function on the
+events per **work-item** in a kernel. First, we'll call this function on the
 kernel from the previous example:
 
 .. doctest::
@@ -1885,8 +1886,8 @@ Now to make things more interesting, we'll create a kernel with barriers:
       }
     }
 
-In this kernel, when a thread performs the second instruction it uses data
-produced by *different* threads during the first instruction. Because of this,
+In this kernel, when a work-item performs the second instruction it uses data
+produced by *different* work-items during the first instruction. Because of this,
 barriers are required for correct execution, so loopy inserts them. Now we'll
 count the barriers using :func:`loopy.get_synchronization_map`:
 
@@ -1898,7 +1899,7 @@ count the barriers using :func:`loopy.get_synchronization_map`:
     kernel_launch : { 1 }
     <BLANKLINE>
 
-Based on the kernel code printed above, we would expect each thread to
+Based on the kernel code printed above, we would expect each work-item to
 encounter 50x10x2 barriers, which matches the result from
 :func:`loopy.get_synchronization_map`. In this case, the number of barriers
 does not depend on any inames, so we can pass an empty dictionary to
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 0f4697f92e3f779b5670147c0fe7936989a317c4..89683e0b466714700f18b090ec365d5861ea4d05 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -121,8 +121,8 @@ from loopy.transform.add_barrier import add_barrier
 from loopy.type_inference import infer_unknown_types
 from loopy.preprocess import preprocess_kernel, realize_reduction
 from loopy.schedule import generate_loop_schedules, get_one_scheduled_kernel
-from loopy.statistics import (ToCountMap, stringify_stats_mapping, Op,
-        MemAccess, get_op_poly, get_op_map, get_lmem_access_poly,
+from loopy.statistics import (ToCountMap, CountGranularity, stringify_stats_mapping,
+        Op, MemAccess, get_op_poly, get_op_map, get_lmem_access_poly,
         get_DRAM_access_poly, get_gmem_access_poly, get_mem_access_map,
         get_synchronization_poly, get_synchronization_map,
         gather_access_footprints, gather_access_footprint_bytes)
@@ -243,8 +243,8 @@ __all__ = [
         "PreambleInfo",
         "generate_code", "generate_code_v2", "generate_body",
 
-        "ToCountMap", "stringify_stats_mapping", "Op", "MemAccess",
-        "get_op_poly", "get_op_map", "get_lmem_access_poly",
+        "ToCountMap", "CountGranularity", "stringify_stats_mapping", "Op",
+        "MemAccess", "get_op_poly", "get_op_map", "get_lmem_access_poly",
         "get_DRAM_access_poly", "get_gmem_access_poly", "get_mem_access_map",
         "get_synchronization_poly", "get_synchronization_map",
         "gather_access_footprints", "gather_access_footprint_bytes",
diff --git a/loopy/statistics.py b/loopy/statistics.py
index a2dcb684620e5cceb821990b5085f5283d252af1..17c5bd3557bd65eddd2d9a35202a604c552e4e19 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -27,12 +27,12 @@ import six
 import loopy as lp
 from islpy import dim_type
 import islpy as isl
-from pytools import memoize_in
 from pymbolic.mapper import CombineMapper
 from functools import reduce
 from loopy.kernel.data import (
         MultiAssignmentBase, TemporaryVariable, temp_var_scope)
 from loopy.diagnostic import warn_with_kernel, LoopyError
+from pytools import Record
 
 
 __doc__ = """
@@ -40,6 +40,7 @@ __doc__ = """
 .. currentmodule:: loopy
 
 .. autoclass:: ToCountMap
+.. autoclass:: CountGranularity
 .. autoclass:: Op
 .. autoclass:: MemAccess
 
@@ -207,13 +208,13 @@ class ToCountMap(object):
     def filter_by(self, **kwargs):
         """Remove items without specified key fields.
 
-        :arg kwargs: Keyword arguments matching fields in the keys of
-                 the :class:`ToCountMap`, each given a list of
-                 allowable values for that key field.
+        :arg kwargs: Keyword arguments matching fields in the keys of the
+            :class:`ToCountMap`, each given a list of allowable values for that
+            key field.
 
         :return: A :class:`ToCountMap` containing the subset of the items in
-                 the original :class:`ToCountMap` that match the field values
-                 passed.
+            the original :class:`ToCountMap` that match the field values
+            passed.
 
         Example usage::
 
@@ -225,7 +226,7 @@ class ToCountMap(object):
                                              variable=['a','g'])
             tot_loads_a_g = filtered_map.eval_and_sum(params)
 
-            # (now use these counts to predict performance)
+            # (now use these counts to, e.g., predict performance)
 
         """
 
@@ -255,11 +256,11 @@ class ToCountMap(object):
     def filter_by_func(self, func):
         """Keep items that pass a test.
 
-        :arg func: A function that takes a map key a parameter and
-             returns a :class:`bool`.
+        :arg func: A function that takes a map key as a parameter and returns a
+            :class:`bool`.
 
-        :arg: A :class:`ToCountMap` containing the subset of the items in
-                 the original :class:`ToCountMap` for which func(key) is true.
+        :return: A :class:`ToCountMap` containing the subset of the items in the
+            original :class:`ToCountMap` for which func(key) is true.
 
         Example usage::
 
@@ -273,7 +274,7 @@ class ToCountMap(object):
             filtered_map = mem_map.filter_by_func(filter_func)
             tot = filtered_map.eval_and_sum(params)
 
-            # (now use these counts to predict performance)
+            # (now use these counts to, e.g., predict performance)
 
         """
 
@@ -288,13 +289,13 @@ class ToCountMap(object):
 
     def group_by(self, *args):
         """Group map items together, distinguishing by only the key fields
-           passed in args.
+        passed in args.
 
         :arg args: Zero or more :class:`str` fields of map keys.
 
-        :return: A :class:`ToCountMap` containing the same total counts
-                 grouped together by new keys that only contain the fields
-                 specified in the arguments passed.
+        :return: A :class:`ToCountMap` containing the same total counts grouped
+            together by new keys that only contain the fields specified in the
+            arguments passed.
 
         Example usage::
 
@@ -328,7 +329,7 @@ class ToCountMap(object):
             f64ops = ops_dtype[Op(dtype=np.float64)].eval_with_dict(params)
             i32ops = ops_dtype[Op(dtype=np.int32)].eval_with_dict(params)
 
-            # (now use these counts to predict performance)
+            # (now use these counts to, e.g., predict performance)
 
         """
 
@@ -361,9 +362,9 @@ class ToCountMap(object):
     def to_bytes(self):
         """Convert counts to bytes using data type in map key.
 
-        :return: A :class:`ToCountMap` mapping each original key to a
-                 :class:`islpy.PwQPolynomial` with counts in bytes rather than
-                 instances.
+        :return: A :class:`ToCountMap` mapping each original key to an
+            :class:`islpy.PwQPolynomial` with counts in bytes rather than
+            instances.
 
         Example usage::
 
@@ -385,7 +386,7 @@ class ToCountMap(object):
                                 mtype=['global'], stride=[2],
                                 direction=['store']).eval_and_sum(params)
 
-            # (now use these counts to predict performance)
+            # (now use these counts to, e.g., predict performance)
 
         """
 
@@ -403,8 +404,8 @@ class ToCountMap(object):
     def sum(self):
         """Add all counts in ToCountMap.
 
-        :return: A :class:`islpy.PwQPolynomial` or :class:`int` containing the sum of
-                 counts.
+        :return: An :class:`islpy.PwQPolynomial` or :class:`int` containing the
+            sum of counts.
 
         """
 
@@ -430,7 +431,7 @@ class ToCountMap(object):
         parameter dict.
 
         :return: An :class:`int` containing the sum of all counts in the
-                 :class:`ToCountMap` evaluated with the parameters provided.
+            :class:`ToCountMap` evaluated with the parameters provided.
 
         Example usage::
 
@@ -442,7 +443,7 @@ class ToCountMap(object):
                                              variable=['a','g'])
             tot_loads_a_g = filtered_map.eval_and_sum(params)
 
-            # (now use these counts to predict performance)
+            # (now use these counts to, e.g., predict performance)
 
         """
         return self.sum().eval_with_dict(params)
@@ -457,9 +458,36 @@ def stringify_stats_mapping(m):
     return result
 
 
+class CountGranularity(object):
+    """Strings specifying whether an operation should be counted once per
+    *work-item*, *sub-group*, or *work-group*.
+
+    .. attribute:: WORKITEM
+
+       A :class:`str` that specifies that an operation should be counted
+       once per *work-item*.
+
+    .. attribute:: SUBGROUP
+
+       A :class:`str` that specifies that an operation should be counted
+       once per *sub-group*.
+
+    .. attribute:: WORKGROUP
+
+       A :class:`str` that specifies that an operation should be counted
+       once per *work-group*.
+
+    """
+
+    WORKITEM = "workitem"
+    SUBGROUP = "subgroup"
+    WORKGROUP = "workgroup"
+    ALL = [WORKITEM, SUBGROUP, WORKGROUP]
+
+
 # {{{ Op descriptor
 
-class Op(object):
+class Op(Record):
     """A descriptor for a type of arithmetic operation.
 
     .. attribute:: dtype
@@ -470,39 +498,49 @@ class Op(object):
     .. attribute:: name
 
        A :class:`str` that specifies the kind of arithmetic operation as
-       *add*, *sub*, *mul*, *div*, *pow*, *shift*, *bw* (bitwise), etc.
+       *add*, *mul*, *div*, *pow*, *shift*, *bw* (bitwise), etc.
 
-    """
+    .. attribute:: count_granularity
 
-    # FIXME: This could be done much more briefly by inheriting from Record.
+       A :class:`str` that specifies whether this operation should be counted
+       once per *work-item*, *sub-group*, or *work-group*. The granularities
+       allowed can be found in :class:`CountGranularity`, and may be accessed,
+       e.g., as ``CountGranularity.WORKITEM``. A work-item is a single instance
+       of computation executing on a single processor (think 'thread'), a
+       collection of which may be grouped together into a work-group. Each
+       work-group executes on a single compute unit with all work-items within
+       the work-group sharing local memory. A sub-group is an
+       implementation-dependent grouping of work-items within a work-group,
+       analogous to an NVIDIA CUDA warp.
 
-    def __init__(self, dtype=None, name=None):
-        self.name = name
+    """
+
+    def __init__(self, dtype=None, name=None, count_granularity=None):
+        if count_granularity not in CountGranularity.ALL+[None]:
+            raise ValueError("Op.__init__: count_granularity '%s' is "
+                    "not allowed. count_granularity options: %s"
+                    % (count_granularity, CountGranularity.ALL+[None]))
         if dtype is None:
-            self.dtype = dtype
+            Record.__init__(self, dtype=dtype, name=name,
+                            count_granularity=count_granularity)
         else:
             from loopy.types import to_loopy_type
-            self.dtype = to_loopy_type(dtype)
-
-    def __eq__(self, other):
-        return isinstance(other, Op) and (
-                (self.dtype is None or other.dtype is None or
-                 self.dtype == other.dtype) and
-                (self.name is None or other.name is None or
-                 self.name == other.name))
+            Record.__init__(self, dtype=to_loopy_type(dtype), name=name,
+                            count_granularity=count_granularity)
 
     def __hash__(self):
         return hash(str(self))
 
     def __repr__(self):
-        return "Op(%s, %s)" % (self.dtype, self.name)
+        # Record.__repr__ overridden for consistent ordering and conciseness
+        return "Op(%s, %s, %s)" % (self.dtype, self.name, self.count_granularity)
 
 # }}}
 
 
 # {{{ MemAccess descriptor
 
-class MemAccess(object):
+class MemAccess(Record):
     """A descriptor for a type of memory access.
 
     .. attribute:: mtype
@@ -517,8 +555,8 @@ class MemAccess(object):
 
     .. attribute:: stride
 
-       An :class:`int` that specifies stride of the memory access. A stride of 0
-       indicates a uniform access (i.e. all threads access the same item).
+       An :class:`int` that specifies stride of the memory access. A stride of
+       0 indicates a uniform access (i.e. all work-items access the same item).
 
     .. attribute:: direction
 
@@ -530,21 +568,23 @@ class MemAccess(object):
        A :class:`str` that specifies the variable name of the data
        accessed.
 
-    """
+    .. attribute:: count_granularity
+
+       A :class:`str` that specifies whether this operation should be counted
+       once per *work-item*, *sub-group*, or *work-group*. The granularities
+       allowed can be found in :class:`CountGranularity`, and may be accessed,
+       e.g., as ``CountGranularity.WORKITEM``. A work-item is a single instance
+       of computation executing on a single processor (think 'thread'), a
+       collection of which may be grouped together into a work-group. Each
+       work-group executes on a single compute unit with all work-items within
+       the work-group sharing local memory. A sub-group is an
+       implementation-dependent grouping of work-items within a work-group,
+       analogous to an NVIDIA CUDA warp.
 
-    # FIXME: This could be done much more briefly by inheriting from Record.
+    """
 
     def __init__(self, mtype=None, dtype=None, stride=None, direction=None,
-                 variable=None):
-        self.mtype = mtype
-        self.stride = stride
-        self.direction = direction
-        self.variable = variable
-        if dtype is None:
-            self.dtype = dtype
-        else:
-            from loopy.types import to_loopy_type
-            self.dtype = to_loopy_type(dtype)
+                 variable=None, count_granularity=None):
 
         #TODO currently giving all lmem access stride=None
         if (mtype == 'local') and (stride is not None):
@@ -556,55 +596,33 @@ class MemAccess(object):
             raise NotImplementedError("MemAccess: variable must be None when "
                                       "mtype is 'local'")
 
-    def copy(self, mtype=None, dtype=None, stride=None, direction=None,
-            variable=None):
-        return MemAccess(
-                mtype=mtype if mtype is not None else self.mtype,
-                dtype=dtype if dtype is not None else self.dtype,
-                stride=stride if stride is not None else self.stride,
-                direction=direction if direction is not None else self.direction,
-                variable=variable if variable is not None else self.variable,
-                )
-
-    def __eq__(self, other):
-        return isinstance(other, MemAccess) and (
-                (self.mtype is None or other.mtype is None or
-                 self.mtype == other.mtype) and
-                (self.dtype is None or other.dtype is None or
-                 self.dtype == other.dtype) and
-                (self.stride is None or other.stride is None or
-                 self.stride == other.stride) and
-                (self.direction is None or other.direction is None or
-                 self.direction == other.direction) and
-                (self.variable is None or other.variable is None or
-                 self.variable == other.variable))
+        if count_granularity not in CountGranularity.ALL+[None]:
+            raise ValueError("MemAccess.__init__: count_granularity '%s' is
+                    "not allowed. count_granularity options: %s"
+                    % (count_granularity, CountGranularity.ALL+[None]))
+
+        if dtype is None:
+            Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride,
+                            direction=direction, variable=variable,
+                            count_granularity=count_granularity)
+        else:
+            from loopy.types import to_loopy_type
+            Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype),
+                            stride=stride, direction=direction, variable=variable,
+                            count_granularity=count_granularity)
 
     def __hash__(self):
         return hash(str(self))
 
     def __repr__(self):
-        if self.mtype is None:
-            mtype = 'None'
-        else:
-            mtype = self.mtype
-        if self.dtype is None:
-            dtype = 'None'
-        else:
-            dtype = str(self.dtype)
-        if self.stride is None:
-            stride = 'None'
-        else:
-            stride = str(self.stride)
-        if self.direction is None:
-            direction = 'None'
-        else:
-            direction = self.direction
-        if self.variable is None:
-            variable = 'None'
-        else:
-            variable = self.variable
-        return "MemAccess(" + mtype + ", " + dtype + ", " + stride + ", " \
-               + direction + ", " + variable + ")"
+        # Record.__repr__ overridden for consistent ordering and conciseness
+        return "MemAccess(%s, %s, %s, %s, %s, %s)" % (
+            self.mtype,
+            self.dtype,
+            self.stride,
+            self.direction,
+            self.variable,
+            self.count_granularity)
 
 # }}}
 
@@ -687,7 +705,8 @@ class ExpressionOpCounter(CounterBase):
     def map_call(self, expr):
         return ToCountMap(
                     {Op(dtype=self.type_inf(expr),
-                        name='func:'+str(expr.function)): 1}
+                        name='func:'+str(expr.function),
+                        count_granularity=CountGranularity.WORKITEM): 1}
                     ) + self.rec(expr.parameters)
 
     def map_subscript(self, expr):
@@ -697,20 +716,28 @@ class ExpressionOpCounter(CounterBase):
         assert expr.children
         return ToCountMap(
                     {Op(dtype=self.type_inf(expr),
-                        name='add'): len(expr.children)-1}
+                        name='add',
+                        count_granularity=CountGranularity.WORKITEM):
+                     len(expr.children)-1}
                     ) + sum(self.rec(child) for child in expr.children)
 
     def map_product(self, expr):
         from pymbolic.primitives import is_zero
         assert expr.children
-        return sum(ToCountMap({Op(dtype=self.type_inf(expr), name='mul'): 1})
+        return sum(ToCountMap({Op(dtype=self.type_inf(expr),
+                                  name='mul',
+                                  count_granularity=CountGranularity.WORKITEM): 1})
                    + self.rec(child)
                    for child in expr.children
                    if not is_zero(child + 1)) + \
-                   ToCountMap({Op(dtype=self.type_inf(expr), name='mul'): -1})
+                   ToCountMap({Op(dtype=self.type_inf(expr),
+                                  name='mul',
+                                  count_granularity=CountGranularity.WORKITEM): -1})
 
     def map_quotient(self, expr, *args):
-        return ToCountMap({Op(dtype=self.type_inf(expr), name='div'): 1}) \
+        return ToCountMap({Op(dtype=self.type_inf(expr),
+                              name='div',
+                              count_granularity=CountGranularity.WORKITEM): 1}) \
                                 + self.rec(expr.numerator) \
                                 + self.rec(expr.denominator)
 
@@ -718,23 +745,31 @@ class ExpressionOpCounter(CounterBase):
     map_remainder = map_quotient
 
     def map_power(self, expr):
-        return ToCountMap({Op(dtype=self.type_inf(expr), name='pow'): 1}) \
+        return ToCountMap({Op(dtype=self.type_inf(expr),
+                              name='pow',
+                              count_granularity=CountGranularity.WORKITEM): 1}) \
                                 + self.rec(expr.base) \
                                 + self.rec(expr.exponent)
 
     def map_left_shift(self, expr):
-        return ToCountMap({Op(dtype=self.type_inf(expr), name='shift'): 1}) \
+        return ToCountMap({Op(dtype=self.type_inf(expr),
+                              name='shift',
+                              count_granularity=CountGranularity.WORKITEM): 1}) \
                                 + self.rec(expr.shiftee) \
                                 + self.rec(expr.shift)
 
     map_right_shift = map_left_shift
 
     def map_bitwise_not(self, expr):
-        return ToCountMap({Op(dtype=self.type_inf(expr), name='bw'): 1}) \
+        return ToCountMap({Op(dtype=self.type_inf(expr),
+                              name='bw',
+                              count_granularity=CountGranularity.WORKITEM): 1}) \
                                 + self.rec(expr.child)
 
     def map_bitwise_or(self, expr):
-        return ToCountMap({Op(dtype=self.type_inf(expr), name='bw'):
+        return ToCountMap({Op(dtype=self.type_inf(expr),
+                              name='bw',
+                              count_granularity=CountGranularity.WORKITEM):
                            len(expr.children)-1}) \
                                 + sum(self.rec(child) for child in expr.children)
 
@@ -756,7 +791,9 @@ class ExpressionOpCounter(CounterBase):
                + self.rec(expr.else_)
 
     def map_min(self, expr):
-        return ToCountMap({Op(dtype=self.type_inf(expr), name='maxmin'):
+        return ToCountMap({Op(dtype=self.type_inf(expr),
+                              name='maxmin',
+                              count_granularity=CountGranularity.WORKITEM):
                            len(expr.children)-1}) \
                + sum(self.rec(child) for child in expr.children)
 
@@ -797,7 +834,8 @@ class LocalMemAccessCounter(MemAccessCounter):
             array = self.knl.temporary_variables[name]
             if isinstance(array, TemporaryVariable) and (
                     array.scope == temp_var_scope.LOCAL):
-                sub_map[MemAccess(mtype='local', dtype=dtype)] = 1
+                sub_map[MemAccess(mtype='local', dtype=dtype,
+                                  count_granularity=CountGranularity.WORKITEM)] = 1
         return sub_map
 
     def map_variable(self, expr):
@@ -833,7 +871,8 @@ class GlobalMemAccessCounter(MemAccessCounter):
 
         return ToCountMap({MemAccess(mtype='global',
                                      dtype=self.type_inf(expr), stride=0,
-                                     variable=name): 1}
+                                     variable=name,
+                                     count_granularity=CountGranularity.WORKITEM): 1}
                           ) + self.rec(expr.index)
 
     def map_subscript(self, expr):
@@ -870,9 +909,11 @@ class GlobalMemAccessCounter(MemAccessCounter):
 
         if not local_id_found:
             # count as uniform access
-            return ToCountMap({MemAccess(mtype='global',
-                                         dtype=self.type_inf(expr), stride=0,
-                                         variable=name): 1}
+            return ToCountMap({MemAccess(
+                                mtype='global',
+                                dtype=self.type_inf(expr), stride=0,
+                                variable=name,
+                                count_granularity=CountGranularity.SUBGROUP): 1}
                               ) + self.rec(expr.index)
 
         if min_tag_axis != 0:
@@ -880,9 +921,11 @@ class GlobalMemAccessCounter(MemAccessCounter):
                              "GlobalSubscriptCounter: Memory access minimum "
                              "tag axis %d != 0, stride unknown, using "
                              "sys.maxsize." % (min_tag_axis))
-            return ToCountMap({MemAccess(mtype='global',
-                                         dtype=self.type_inf(expr),
-                                         stride=sys.maxsize, variable=name): 1}
+            return ToCountMap({MemAccess(
+                                mtype='global',
+                                dtype=self.type_inf(expr),
+                                stride=sys.maxsize, variable=name,
+                                count_granularity=CountGranularity.WORKITEM): 1}
                               ) + self.rec(expr.index)
 
         # get local_id associated with minimum tag axis
@@ -926,8 +969,16 @@ class GlobalMemAccessCounter(MemAccessCounter):
 
             total_stride += stride*coeff_min_lid
 
-        return ToCountMap({MemAccess(mtype='global', dtype=self.type_inf(expr),
-                                     stride=total_stride, variable=name): 1}
+        count_granularity = CountGranularity.WORKITEM if total_stride != 0 \
+                                else CountGranularity.SUBGROUP
+
+        return ToCountMap({MemAccess(
+                            mtype='global',
+                            dtype=self.type_inf(expr),
+                            stride=total_stride,
+                            variable=name,
+                            count_granularity=count_granularity
+                            ): 1}
                           ) + self.rec(expr.index)
 
 # }}}
@@ -1144,6 +1195,7 @@ def get_unused_hw_axes_factor(knl, insn, disregard_local_axes, space=None):
 
 
 def count_insn_runs(knl, insn, count_redundant_work, disregard_local_axes=False):
+
     insn_inames = knl.insn_inames(insn)
 
     if disregard_local_axes:
@@ -1173,21 +1225,34 @@ def count_insn_runs(knl, insn, count_redundant_work, disregard_local_axes=False)
 
 # {{{ get_op_map
 
-def get_op_map(knl, numpy_types=True, count_redundant_work=False):
+def get_op_map(knl, numpy_types=True, count_redundant_work=False,
+               subgroup_size=None):
 
     """Count the number of operations in a loopy kernel.
 
     :arg knl: A :class:`loopy.LoopKernel` whose operations are to be counted.
 
-    :arg numpy_types: A :class:`bool` specifying whether the types
-         in the returned mapping should be numpy types
-         instead of :class:`loopy.LoopyType`.
+    :arg numpy_types: A :class:`bool` specifying whether the types in the
+        returned mapping should be numpy types instead of
+        :class:`loopy.LoopyType`.
 
     :arg count_redundant_work: Based on usage of hardware axes or other
         specifics, a kernel may perform work redundantly. This :class:`bool`
         flag indicates whether this work should be included in the count.
-        (Likely desirable for performance modeling, but undesirable for
-        code optimization.)
+        (Likely desirable for performance modeling, but undesirable for code
+        optimization.)
+
+    :arg subgroup_size: (currently unused) An :class:`int`, :class:`str`
+        ``'guess'``, or *None* that specifies the sub-group size. An OpenCL
+        sub-group is an implementation-dependent grouping of work-items within
+        a work-group, analogous to an NVIDIA CUDA warp. subgroup_size is used,
+        e.g., when counting a :class:`MemAccess` whose count_granularity
+        specifies that it should only be counted once per sub-group. If set to
+        *None* an attempt to find the sub-group size using the device will be
+        made, if this fails an error will be raised. If a :class:`str`
+        ``'guess'`` is passed as the subgroup_size, get_mem_access_map will
+        attempt to find the sub-group size using the device and, if
+        unsuccessful, will make a wild guess.
 
     :return: A :class:`ToCountMap` of **{** :class:`Op` **:**
         :class:`islpy.PwQPolynomial` **}**.
@@ -1205,10 +1270,16 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False):
 
         op_map = get_op_map(knl)
         params = {'n': 512, 'm': 256, 'l': 128}
-        f32add = op_map[Op(np.float32, 'add')].eval_with_dict(params)
-        f32mul = op_map[Op(np.float32, 'mul')].eval_with_dict(params)
+        f32add = op_map[Op(np.float32,
+                           'add',
+                           count_granularity=CountGranularity.WORKITEM)
+                       ].eval_with_dict(params)
+        f32mul = op_map[Op(np.float32,
+                           'mul',
+                           count_granularity=CountGranularity.WORKITEM)
+                       ].eval_with_dict(params)
 
-        # (now use these counts to predict performance)
+        # (now use these counts to, e.g., predict performance)
 
     """
 
@@ -1234,26 +1305,48 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False):
                     % type(insn).__name__)
 
     if numpy_types:
-        op_map.count_map = dict((Op(dtype=op.dtype.numpy_dtype, name=op.name),
-                                 count)
-                for op, count in six.iteritems(op_map.count_map))
-
-    return op_map
+        return ToCountMap(
+                    init_dict=dict(
+                        (Op(
+                            dtype=op.dtype.numpy_dtype,
+                            name=op.name,
+                            count_granularity=op.count_granularity),
+                        ct)
+                        for op, ct in six.iteritems(op_map.count_map)),
+                    val_type=op_map.val_type
+                    )
+    else:
+        return op_map
 
 # }}}
 
 
+def _find_subgroup_size_for_knl(knl):
+    from loopy.target.pyopencl import PyOpenCLTarget
+    if isinstance(knl.target, PyOpenCLTarget) and knl.target.device is not None:
+        from pyopencl.characterize import get_simd_group_size
+        subgroup_size_guess = get_simd_group_size(knl.target.device, None)
+        warn_with_kernel(knl, "getting_subgroup_size_from_device",
+                         "Device: %s. Using sub-group size given by "
+                         "pyopencl.characterize.get_simd_group_size(): %d"
+                         % (knl.target.device, subgroup_size_guess))
+        return subgroup_size_guess
+    else:
+        return None
+
+
 # {{{ get_mem_access_map
 
-def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False):
+def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
+                       subgroup_size=None):
     """Count the number of memory accesses in a loopy kernel.
 
     :arg knl: A :class:`loopy.LoopKernel` whose memory accesses are to be
         counted.
 
-    :arg numpy_types: A :class:`bool` specifying whether the types
-        in the returned mapping should be numpy types
-        instead of :class:`loopy.LoopyType`.
+    :arg numpy_types: A :class:`bool` specifying whether the types in the
+        returned mapping should be numpy types instead of
+        :class:`loopy.LoopyType`.
 
     :arg count_redundant_work: Based on usage of hardware axes or other
         specifics, a kernel may perform work redundantly. This :class:`bool`
@@ -1261,15 +1354,27 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False):
         (Likely desirable for performance modeling, but undesirable for
         code optimization.)
 
+    :arg subgroup_size: An :class:`int`, :class:`str` ``'guess'``, or
+        *None* that specifies the sub-group size. An OpenCL sub-group is an
+        implementation-dependent grouping of work-items within a work-group,
+        analogous to an NVIDIA CUDA warp. subgroup_size is used, e.g., when
+        counting a :class:`MemAccess` whose count_granularity specifies that it
+        should only be counted once per sub-group. If set to *None* an attempt
+        to find the sub-group size using the device will be made, if this fails
+        an error will be raised. If a :class:`str` ``'guess'`` is passed as
+        the subgroup_size, get_mem_access_map will attempt to find the
+        sub-group size using the device and, if unsuccessful, will make a wild
+        guess.
+
     :return: A :class:`ToCountMap` of **{** :class:`MemAccess` **:**
         :class:`islpy.PwQPolynomial` **}**.
 
-        - The :class:`MemAccess` specifies the characteristics of the
-          memory access.
+        - The :class:`MemAccess` specifies the characteristics of the memory
+          access.
 
-        - The :class:`islpy.PwQPolynomial` holds the number of memory
-          accesses with the characteristics specified in the key (in terms
-          of the :class:`loopy.LoopKernel` *inames*).
+        - The :class:`islpy.PwQPolynomial` holds the number of memory accesses
+          with the characteristics specified in the key (in terms of the
+          :class:`loopy.LoopKernel` *inames*).
 
     Example usage::
 
@@ -1278,48 +1383,132 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False):
         params = {'n': 512, 'm': 256, 'l': 128}
         mem_map = get_mem_access_map(knl)
 
-        f32_s1_g_ld_a = mem_map[MemAccess(mtype='global',
-                                          dtype=np.float32,
-                                          stride=1,
-                                          direction='load',
-                                          variable='a')
+        f32_s1_g_ld_a = mem_map[MemAccess(
+                                    mtype='global',
+                                    dtype=np.float32,
+                                    stride=1,
+                                    direction='load',
+                                    variable='a',
+                                    count_granularity=CountGranularity.WORKITEM)
                                ].eval_with_dict(params)
-        f32_s1_g_st_a = mem_map[MemAccess(mtype='global',
-                                          dtype=np.float32,
-                                          stride=1,
-                                          direction='store',
-                                          variable='a')
+        f32_s1_g_st_a = mem_map[MemAccess(
+                                    mtype='global',
+                                    dtype=np.float32,
+                                    stride=1,
+                                    direction='store',
+                                    variable='a',
+                                    count_granularity=CountGranularity.WORKITEM)
                                ].eval_with_dict(params)
-        f32_s1_l_ld_x = mem_map[MemAccess(mtype='local',
-                                          dtype=np.float32,
-                                          stride=1,
-                                          direction='load',
-                                          variable='x')
+        f32_s1_l_ld_x = mem_map[MemAccess(
+                                    mtype='local',
+                                    dtype=np.float32,
+                                    stride=1,
+                                    direction='load',
+                                    variable='x',
+                                    count_granularity=CountGranularity.WORKITEM)
                                ].eval_with_dict(params)
-        f32_s1_l_st_x = mem_map[MemAccess(mtype='local',
-                                          dtype=np.float32,
-                                          stride=1,
-                                          direction='store',
-                                          variable='x')
+        f32_s1_l_st_x = mem_map[MemAccess(
+                                    mtype='local',
+                                    dtype=np.float32,
+                                    stride=1,
+                                    direction='store',
+                                    variable='x',
+                                    count_granularity=CountGranularity.WORKITEM)
                                ].eval_with_dict(params)
 
-        # (now use these counts to predict performance)
+        # (now use these counts to, e.g., predict performance)
 
     """
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
 
+    if not isinstance(subgroup_size, int):
+        # try to find subgroup_size
+        subgroup_size_guess = _find_subgroup_size_for_knl(knl)
+
+        if subgroup_size is None:
+            if subgroup_size_guess is None:
+                # 'guess' was not passed and either no target device found
+                # or get_simd_group_size returned None
+                raise ValueError("No sub-group size passed, no target device found. "
+                                 "Either (1) pass integer value for subgroup_size, "
+                                 "(2) ensure that kernel.target is PyOpenCLTarget "
+                                 "and kernel.target.device is set, or (3) pass "
+                                 "subgroup_size='guess' and hope for the best.")
+            else:
+                subgroup_size = subgroup_size_guess
+
+        elif subgroup_size == 'guess':
+            if subgroup_size_guess is None:
+                # unable to get subgroup_size from device, so guess
+                subgroup_size = 32
+                warn_with_kernel(knl, "get_mem_access_map_guessing_subgroup_size",
+                                 "get_mem_access_map: 'guess' sub-group size "
+                                 "passed, no target device found, wildly guessing "
+                                 "that sub-group size is %d." % (subgroup_size))
+            else:
+                subgroup_size = subgroup_size_guess
+        else:
+            raise ValueError("Invalid value for subgroup_size: %s. subgroup_size "
+                             "must be integer, 'guess', or, if you're feeling "
+                             "lucky, None." % (subgroup_size))
+
     class CacheHolder(object):
         pass
 
     cache_holder = CacheHolder()
+    from pytools import memoize_in
 
     @memoize_in(cache_holder, "insn_count")
-    def get_insn_count(knl, insn_id, uniform=False):
+    def get_insn_count(knl, insn_id, count_granularity=CountGranularity.WORKITEM):
         insn = knl.id_to_insn[insn_id]
-        return count_insn_runs(
-                knl, insn, disregard_local_axes=uniform,
+
+        if count_granularity is None:
+            warn_with_kernel(knl, "get_insn_count_assumes_granularity",
+                             "get_insn_count: No count granularity passed for "
+                             "MemAccess, assuming %s granularity."
+                             % (CountGranularity.WORKITEM))
+            count_granularity = CountGranularity.WORKITEM
+
+        if count_granularity == CountGranularity.WORKITEM:
+            return count_insn_runs(
+                knl, insn, count_redundant_work=count_redundant_work,
+                disregard_local_axes=False)
+
+        ct_disregard_local = count_insn_runs(
+                knl, insn, disregard_local_axes=True,
                 count_redundant_work=count_redundant_work)
 
+        if count_granularity == CountGranularity.WORKGROUP:
+            return ct_disregard_local
+        elif count_granularity == CountGranularity.SUBGROUP:
+            # get the group size
+            from loopy.symbolic import aff_to_expr
+            _, local_size = knl.get_grid_size_upper_bounds()
+            workgroup_size = 1
+            if local_size:
+                for size in local_size:
+                    s = aff_to_expr(size)
+                    if not isinstance(s, int):
+                        raise LoopyError("Cannot count insn with %s granularity, "
+                                         "work-group size is not integer: %s"
+                                         % (CountGranularity.SUBGROUP, local_size))
+                    workgroup_size *= s
+
+            warn_with_kernel(knl, "insn_count_subgroups_upper_bound",
+                    "get_insn_count: when counting instruction %s with "
+                    "count_granularity=%s, using upper bound for work-group size "
+                    "(%d work-items) to compute sub-groups per work-group. When "
+                    "multiple device programs present, actual sub-group count may be "
+                    "lower." % (insn_id, CountGranularity.SUBGROUP, workgroup_size))
+
+            from pytools import div_ceil
+            return ct_disregard_local*div_ceil(workgroup_size, subgroup_size)
+        else:
+            # this should not happen since this is enforced in MemAccess
+            raise ValueError("get_insn_count: count_granularity '%s' is "
+                    "not allowed. count_granularity options: %s"
+                    % (count_granularity, CountGranularity.ALL+[None]))
+
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
 
@@ -1342,26 +1531,23 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False):
                     direction="store")
 
             # FIXME: (!!!!) for now, don't count writes to local mem
+            # (^this is updated in a branch that will be merged soon)
 
             # use count excluding local index tags for uniform accesses
             for key, val in six.iteritems(access_expr.count_map):
-                is_uniform = (key.mtype == 'global' and
-                        isinstance(key.stride, int) and
-                        key.stride == 0)
+
                 access_map = (
                         access_map
                         + ToCountMap({key: val})
-                        * get_insn_count(knl, insn.id, is_uniform))
+                        * get_insn_count(knl, insn.id, key.count_granularity))
                 #currently not counting stride of local mem access
 
             for key, val in six.iteritems(access_assignee_g.count_map):
-                is_uniform = (key.mtype == 'global' and
-                        isinstance(key.stride, int) and
-                        key.stride == 0)
+
                 access_map = (
                         access_map
                         + ToCountMap({key: val})
-                        * get_insn_count(knl, insn.id, is_uniform))
+                        * get_insn_count(knl, insn.id, key.count_granularity))
                 # for now, don't count writes to local mem
         elif isinstance(insn, (NoOpInstruction, BarrierInstruction)):
             pass
@@ -1370,35 +1556,52 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False):
                     % type(insn).__name__)
 
     if numpy_types:
-        # FIXME: Don't modify in-place
-        access_map.count_map = dict((MemAccess(mtype=mem_access.mtype,
-                                             dtype=mem_access.dtype.numpy_dtype,
-                                             stride=mem_access.stride,
-                                             direction=mem_access.direction,
-                                             variable=mem_access.variable),
-                                  count)
-                      for mem_access, count in six.iteritems(access_map.count_map))
-
-    return access_map
+        return ToCountMap(
+                    init_dict=dict(
+                        (MemAccess(
+                            mtype=mem_access.mtype,
+                            dtype=mem_access.dtype.numpy_dtype,
+                            stride=mem_access.stride,
+                            direction=mem_access.direction,
+                            variable=mem_access.variable,
+                            count_granularity=mem_access.count_granularity),
+                        ct)
+                        for mem_access, ct in six.iteritems(access_map.count_map)),
+                    val_type=access_map.val_type
+                    )
+    else:
+        return access_map
 
 # }}}
 
 
 # {{{ get_synchronization_map
 
-def get_synchronization_map(knl):
+def get_synchronization_map(knl, subgroup_size=None):
 
-    """Count the number of synchronization events each thread encounters in a
-    loopy kernel.
+    """Count the number of synchronization events each work-item encounters in
+    a loopy kernel.
 
     :arg knl: A :class:`loopy.LoopKernel` whose barriers are to be counted.
 
-    :return: A dictionary mapping each type of synchronization event to a
-            :class:`islpy.PwQPolynomial` holding the number of events per
-            thread.
-
-            Possible keys include ``barrier_local``, ``barrier_global``
-            (if supported by the target) and ``kernel_launch``.
+    :arg subgroup_size: (currently unused) An :class:`int`, :class:`str`
+        ``'guess'``, or *None* that specifies the sub-group size. An OpenCL
+        sub-group is an implementation-dependent grouping of work-items within
+        a work-group, analogous to an NVIDIA CUDA warp. subgroup_size is used,
+        e.g., when counting a :class:`MemAccess` whose count_granularity
+        specifies that it should only be counted once per sub-group. If set to
+        *None* an attempt to find the sub-group size using the device will be
+        made, if this fails an error will be raised. If a :class:`str`
+        ``'guess'`` is passed as the subgroup_size, get_mem_access_map will
+        attempt to find the sub-group size using the device and, if
+        unsuccessful, will make a wild guess.
+
+    :return: A dictionary mapping each type of synchronization event to an
+        :class:`islpy.PwQPolynomial` holding the number of events per
+        work-item.
+
+        Possible keys include ``barrier_local``, ``barrier_global``
+        (if supported by the target) and ``kernel_launch``.
 
     Example usage::
 
@@ -1408,7 +1611,7 @@ def get_synchronization_map(knl):
         params = {'n': 512, 'm': 256, 'l': 128}
         barrier_ct = sync_map['barrier_local'].eval_with_dict(params)
 
-        # (now use this count to predict performance)
+        # (now use this count to, e.g., predict performance)
 
     """
 
@@ -1467,14 +1670,14 @@ def get_synchronization_map(knl):
 # {{{ gather_access_footprints
 
 def gather_access_footprints(kernel, ignore_uncountable=False):
-    """Return a dictionary mapping ``(var_name, direction)``
-    to :class:`islpy.Set` instances capturing which indices
-    of each the array *var_name* are read/written (where
-    *direction* is either ``read`` or ``write``.
-
-    :arg ignore_uncountable: If *False*, an error will be raised for
-        accesses on which the footprint cannot be determined (e.g.
-        data-dependent or nonlinear indices)
+    """Return a dictionary mapping ``(var_name, direction)`` to
+    :class:`islpy.Set` instances capturing which indices of each the array
+    *var_name* are read/written (where *direction* is either ``read`` or
+    ``write``.
+
+    :arg ignore_uncountable: If *False*, an error will be raised for accesses
+        on which the footprint cannot be determined (e.g. data-dependent or
+        nonlinear indices)
     """
 
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
@@ -1526,9 +1729,9 @@ def gather_access_footprint_bytes(kernel, ignore_uncountable=False):
     read/written (where *direction* is either ``read`` or ``write`` on array
     *var_name*
 
-    :arg ignore_uncountable: If *True*, an error will be raised for
-        accesses on which the footprint cannot be determined (e.g.
-        data-dependent or nonlinear indices)
+    :arg ignore_uncountable: If *True*, an error will be raised for accesses on
+        which the footprint cannot be determined (e.g. data-dependent or
+        nonlinear indices)
     """
 
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
@@ -1604,10 +1807,11 @@ def get_gmem_access_poly(knl):
 
 
 def get_synchronization_poly(knl):
-    """Count the number of synchronization events each thread encounters in a
-    loopy kernel.
+    """Count the number of synchronization events each work-item encounters in
+    a loopy kernel.
 
-    get_synchronization_poly is deprecated. Use get_synchronization_map instead.
+    get_synchronization_poly is deprecated. Use get_synchronization_map
+    instead.
 
     """
     warn_with_kernel(knl, "deprecated_get_synchronization_poly",
diff --git a/setup.py b/setup.py
index bd94ea7e7e387709684adbc43a5753f1395df2f7..ffa5e2fea79cbe82fe1d0c1c17137833620fa7d5 100644
--- a/setup.py
+++ b/setup.py
@@ -62,7 +62,7 @@ setup(name="loo.py",
           },
 
       dependency_links=[
-          "hg+https://bitbucket.org/inducer/f2py#egg=f2py==0.3.1"
+          "git+https://github.com/pearu/f2py.git"
           ],
 
       scripts=["bin/loopy"],
diff --git a/test/test_numa_diff.py b/test/test_numa_diff.py
index a287ad59d7697eef79336678afa831e73b81784b..216f7f637eb06ad9dbeff76d958a04869c8e3457 100644
--- a/test/test_numa_diff.py
+++ b/test/test_numa_diff.py
@@ -233,7 +233,7 @@ def test_gnuma_horiz_kernel(ctx_factory, ilp_multiple, Nq, opt_level):  # noqa
         print(lp.stringify_stats_mapping(op_map))
 
         print("MEM")
-        gmem_map = lp.get_mem_access_map(hsv).to_bytes()
+        gmem_map = lp.get_mem_access_map(hsv, subgroup_size=32).to_bytes()
         print(lp.stringify_stats_mapping(gmem_map))
 
     hsv = lp.set_options(hsv, cl_build_options=[
diff --git a/test/test_statistics.py b/test/test_statistics.py
index e4232e613c569cb4a0d66b500a981643bf5bac05..b9c7185c21af4782af8fb284e72ac6041d5f98da 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -30,6 +30,8 @@ from pyopencl.tools import (  # noqa
 import loopy as lp
 from loopy.types import to_loopy_type
 import numpy as np
+from pytools import div_ceil
+from loopy.statistics import CountGranularity as CG
 
 from pymbolic.primitives import Variable
 
@@ -57,11 +59,13 @@ def test_op_counter_basic():
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32add = op_map[lp.Op(np.float32, 'add')].eval_with_dict(params)
-    f32mul = op_map[lp.Op(np.float32, 'mul')].eval_with_dict(params)
-    f32div = op_map[lp.Op(np.float32, 'div')].eval_with_dict(params)
-    f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul')].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add')].eval_with_dict(params)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
+    f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params)
+    f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.WORKITEM)
+                    ].eval_with_dict(params)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+                    ].eval_with_dict(params)
     assert f32add == f32mul == f32div == n*m*ell
     assert f64mul == n*m
     assert i32add == n*m*2
@@ -82,8 +86,9 @@ def test_op_counter_reduction():
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32add = op_map[lp.Op(np.float32, 'add')].eval_with_dict(params)
-    f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul')].eval_with_dict(params)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.WORKITEM)
+                    ].eval_with_dict(params)
     assert f32add == f32mul == n*m*ell
 
     op_map_dtype = op_map.group_by('dtype')
@@ -111,10 +116,12 @@ def test_op_counter_logic():
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32mul = op_map[lp.Op(np.float32, 'mul')].eval_with_dict(params)
-    f64add = op_map[lp.Op(np.float64, 'add')].eval_with_dict(params)
-    f64div = op_map[lp.Op(np.dtype(np.float64), 'div')].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add')].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
+    f64add = op_map[lp.Op(np.float64, 'add', CG.WORKITEM)].eval_with_dict(params)
+    f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.WORKITEM)
+                    ].eval_with_dict(params)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+                    ].eval_with_dict(params)
     assert f32mul == n*m
     assert f64div == 2*n*m  # TODO why?
     assert f64add == n*m
@@ -141,14 +148,18 @@ def test_op_counter_specialops():
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32mul = op_map[lp.Op(np.float32, 'mul')].eval_with_dict(params)
-    f32div = op_map[lp.Op(np.float32, 'div')].eval_with_dict(params)
-    f32add = op_map[lp.Op(np.float32, 'add')].eval_with_dict(params)
-    f64pow = op_map[lp.Op(np.float64, 'pow')].eval_with_dict(params)
-    f64add = op_map[lp.Op(np.dtype(np.float64), 'add')].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add')].eval_with_dict(params)
-    f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt')].eval_with_dict(params)
-    f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin')].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
+    f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
+    f64pow = op_map[lp.Op(np.float64, 'pow', CG.WORKITEM)].eval_with_dict(params)
+    f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.WORKITEM)
+                    ].eval_with_dict(params)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+                    ].eval_with_dict(params)
+    f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.WORKITEM)
+                    ].eval_with_dict(params)
+    f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin', CG.WORKITEM)
+                    ].eval_with_dict(params)
     assert f32div == 2*n*m*ell
     assert f32mul == f32add == n*m*ell
     assert f64add == 3*n*m
@@ -177,12 +188,16 @@ def test_op_counter_bitwise():
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    i32add = op_map[lp.Op(np.int32, 'add')].eval_with_dict(params)
-    i32bw = op_map[lp.Op(np.int32, 'bw')].eval_with_dict(params)
-    i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw')].eval_with_dict(params)
-    i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul')].eval_with_dict(params)
-    i64add = op_map[lp.Op(np.dtype(np.int64), 'add')].eval_with_dict(params)
-    i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift')].eval_with_dict(params)
+    i32add = op_map[lp.Op(np.int32, 'add', CG.WORKITEM)].eval_with_dict(params)
+    i32bw = op_map[lp.Op(np.int32, 'bw', CG.WORKITEM)].eval_with_dict(params)
+    i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.WORKITEM)
+                   ].eval_with_dict(params)
+    i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.WORKITEM)
+                    ].eval_with_dict(params)
+    i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.WORKITEM)
+                    ].eval_with_dict(params)
+    i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.WORKITEM)
+                      ].eval_with_dict(params)
     assert i32add == n*m+n*m*ell
     assert i32bw == 2*n*m*ell
     assert i64bw == 2*n*m
@@ -211,7 +226,10 @@ def test_op_counter_triangular_domain():
     else:
         expect_fallback = False
 
-    op_map = lp.get_op_map(knl, count_redundant_work=True)[lp.Op(np.float64, 'mul')]
+    op_map = lp.get_op_map(
+                    knl,
+                    count_redundant_work=True
+                    )[lp.Op(np.float64, 'mul', CG.WORKITEM)]
     value_dict = dict(m=13, n=200)
     flops = op_map.eval_with_dict(value_dict)
 
@@ -234,35 +252,55 @@ def test_mem_access_counter_basic():
             name="basic", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl,
-                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+                    dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
+
+    subgroup_size = 32
+
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=subgroup_size)
+
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
+
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, subgroup_size)
+
     f32l = mem_map[lp.MemAccess('global', np.float32,
-                         stride=0, direction='load', variable='a')
+                         stride=0, direction='load', variable='a',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
     f32l += mem_map[lp.MemAccess('global', np.float32,
-                          stride=0, direction='load', variable='b')
+                         stride=0, direction='load', variable='b',
+                         count_granularity=CG.SUBGROUP)
                     ].eval_with_dict(params)
     f64l = mem_map[lp.MemAccess('global', np.float64,
-                         stride=0, direction='load', variable='g')
+                         stride=0, direction='load', variable='g',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
     f64l += mem_map[lp.MemAccess('global', np.float64,
-                          stride=0, direction='load', variable='h')
+                         stride=0, direction='load', variable='h',
+                         count_granularity=CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32l == 3*n*m*ell
-    assert f64l == 2*n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32l == (3*n*m*ell)*n_workgroups*subgroups_per_group
+    assert f64l == (2*n*m)*n_workgroups*subgroups_per_group
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
-                         stride=0, direction='store', variable='c')
+                         stride=0, direction='store', variable='c',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
     f64s = mem_map[lp.MemAccess('global', np.dtype(np.float64),
-                         stride=0, direction='store', variable='e')
+                         stride=0, direction='store', variable='e',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
-    assert f32s == n*m*ell
-    assert f64s == n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32s == (n*m*ell)*n_workgroups*subgroups_per_group
+    assert f64s == (n*m)*n_workgroups*subgroups_per_group
 
 
 def test_mem_access_counter_reduction():
@@ -275,23 +313,39 @@ def test_mem_access_counter_reduction():
             name="matmul", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+
+    subgroup_size = 32
+
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=subgroup_size)
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
+
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, subgroup_size)
+
     f32l = mem_map[lp.MemAccess('global', np.float32,
-                         stride=0, direction='load', variable='a')
+                         stride=0, direction='load', variable='a',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
     f32l += mem_map[lp.MemAccess('global', np.float32,
-                          stride=0, direction='load', variable='b')
+                         stride=0, direction='load', variable='b',
+                         count_granularity=CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32l == 2*n*m*ell
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32l == (2*n*m*ell)*n_workgroups*subgroups_per_group
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
-                         stride=0, direction='store', variable='c')
+                         stride=0, direction='store', variable='c',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
-    assert f32s == n*ell
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32s == (n*ell)*n_workgroups*subgroups_per_group
 
     ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load']
                                  ).to_bytes().eval_and_sum(params)
@@ -315,12 +369,20 @@ def test_mem_access_counter_logic():
             name="logic", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+
+    subgroup_size = 32
+
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=subgroup_size)
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
 
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, subgroup_size)
+
     reduced_map = mem_map.group_by('mtype', 'dtype', 'direction')
 
     f32_g_l = reduced_map[lp.MemAccess('global', to_loopy_type(np.float32),
@@ -332,9 +394,11 @@ def test_mem_access_counter_logic():
     f64_g_s = reduced_map[lp.MemAccess('global', to_loopy_type(np.float64),
                                        direction='store')
                           ].eval_with_dict(params)
-    assert f32_g_l == 2*n*m
-    assert f64_g_l == n*m
-    assert f64_g_s == n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32_g_l == (2*n*m)*n_workgroups*subgroups_per_group
+    assert f64_g_l == (n*m)*n_workgroups*subgroups_per_group
+    assert f64_g_s == (n*m)*n_workgroups*subgroups_per_group
 
 
 def test_mem_access_counter_specialops():
@@ -351,39 +415,60 @@ def test_mem_access_counter_specialops():
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32,
                                             g=np.float64, h=np.float64))
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+
+    subgroup_size = 32
+
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=subgroup_size)
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
+
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, subgroup_size)
+
     f32 = mem_map[lp.MemAccess('global', np.float32,
-                         stride=0, direction='load', variable='a')
+                         stride=0, direction='load', variable='a',
+                         count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
     f32 += mem_map[lp.MemAccess('global', np.float32,
-                          stride=0, direction='load', variable='b')
+                         stride=0, direction='load', variable='b',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
     f64 = mem_map[lp.MemAccess('global', np.dtype(np.float64),
-                         stride=0, direction='load', variable='g')
+                         stride=0, direction='load', variable='g',
+                         count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
     f64 += mem_map[lp.MemAccess('global', np.dtype(np.float64),
-                          stride=0, direction='load', variable='h')
+                         stride=0, direction='load', variable='h',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
-    assert f32 == 2*n*m*ell
-    assert f64 == 2*n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32 == (2*n*m*ell)*n_workgroups*subgroups_per_group
+    assert f64 == (2*n*m)*n_workgroups*subgroups_per_group
 
     f32 = mem_map[lp.MemAccess('global', np.float32,
-                         stride=0, direction='store', variable='c')
+                         stride=0, direction='store', variable='c',
+                         count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
     f64 = mem_map[lp.MemAccess('global', np.float64,
-                         stride=0, direction='store', variable='e')
+                         stride=0, direction='store', variable='e',
+                         count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
-    assert f32 == n*m*ell
-    assert f64 == n*m
 
-    filtered_map = mem_map.filter_by(direction=['load'], variable=['a', 'g'])
-    #tot = lp.eval_and_sum_polys(filtered_map, params)
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32 == (n*m*ell)*n_workgroups*subgroups_per_group
+    assert f64 == (n*m)*n_workgroups*subgroups_per_group
+
+    filtered_map = mem_map.filter_by(direction=['load'], variable=['a', 'g'],
+                         count_granularity=CG.SUBGROUP)
     tot = filtered_map.eval_and_sum(params)
-    assert tot == n*m*ell + n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert tot == (n*m*ell + n*m)*n_workgroups*subgroups_per_group
 
 
 def test_mem_access_counter_bitwise():
@@ -403,36 +488,53 @@ def test_mem_access_counter_bitwise():
                 a=np.int32, b=np.int32,
                 g=np.int32, h=np.int32))
 
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+    subgroup_size = 32
+
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=subgroup_size)
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
+
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, subgroup_size)
+
     i32 = mem_map[lp.MemAccess('global', np.int32,
-                         stride=0, direction='load', variable='a')
+                         stride=0, direction='load', variable='a',
+                         count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.int32,
-                          stride=0, direction='load', variable='b')
+                         stride=0, direction='load', variable='b',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.int32,
-                          stride=0, direction='load', variable='g')
+                         stride=0, direction='load', variable='g',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.dtype(np.int32),
-                          stride=0, direction='load', variable='h')
+                         stride=0, direction='load', variable='h',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
-    assert i32 == 4*n*m+2*n*m*ell
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert i32 == (4*n*m+2*n*m*ell)*n_workgroups*subgroups_per_group
 
     i32 = mem_map[lp.MemAccess('global', np.int32,
-                         stride=0, direction='store', variable='c')
+                         stride=0, direction='store', variable='c',
+                         count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
     i32 += mem_map[lp.MemAccess('global', np.int32,
-                          stride=0, direction='store', variable='e')
+                         stride=0, direction='store', variable='e',
+                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
-    assert i32 == n*m+n*m*ell
 
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert i32 == (n*m+n*m*ell)*n_workgroups*subgroups_per_group
 
-def test_mem_access_counter_mixed():
 
+def test_mem_access_counter_mixed():
     knl = lp.make_kernel(
             "[n,m,ell] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<ell}",
             [
@@ -442,48 +544,92 @@ def test_mem_access_counter_mixed():
             """
             ],
             name="mixed", assumptions="n,m,ell >= 1")
+
     knl = lp.add_and_infer_dtypes(knl, dict(
                 a=np.float32, b=np.float32, g=np.float64, h=np.float64,
                 x=np.float32))
-    bsize = 16
-    knl = lp.split_iname(knl, "j", bsize)
+
+    group_size_0 = 65
+    subgroup_size = 32
+
+    knl = lp.split_iname(knl, "j", group_size_0)
     knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"})
 
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)  # noqa
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
+
+    n_workgroups = div_ceil(ell, group_size_0)
+    group_size = group_size_0
+    subgroups_per_group = div_ceil(group_size, subgroup_size)
+
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=subgroup_size)
     f64uniform = mem_map[lp.MemAccess('global', np.float64,
-                                stride=0, direction='load', variable='g')
+                                stride=0, direction='load', variable='g',
+                                count_granularity=CG.SUBGROUP)
                          ].eval_with_dict(params)
     f64uniform += mem_map[lp.MemAccess('global', np.float64,
-                                 stride=0, direction='load', variable='h')
+                                stride=0, direction='load', variable='h',
+                                count_granularity=CG.SUBGROUP)
                           ].eval_with_dict(params)
     f32uniform = mem_map[lp.MemAccess('global', np.float32,
-                                stride=0, direction='load', variable='x')
+                                stride=0, direction='load', variable='x',
+                                count_granularity=CG.SUBGROUP)
                          ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.dtype(np.float32),
-                                  stride=Variable('m'), direction='load',
-                                  variable='a')
+                                stride=Variable('m'), direction='load',
+                                variable='a',
+                                count_granularity=CG.WORKITEM)
                            ].eval_with_dict(params)
     f32nonconsec += mem_map[lp.MemAccess('global', np.dtype(np.float32),
-                                   stride=Variable('m'), direction='load',
-                                   variable='b')
+                                stride=Variable('m'), direction='load',
+                                variable='b',
+                                count_granularity=CG.WORKITEM)
                             ].eval_with_dict(params)
-    assert f64uniform == 2*n*m*ell/bsize
-    assert f32uniform == n*m*ell/bsize
-    assert f32nonconsec == 3*n*m*ell
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f64uniform == (2*n*m)*n_workgroups*subgroups_per_group
+    assert f32uniform == (m*n)*n_workgroups*subgroups_per_group
+
+    expect_fallback = False
+    import islpy as isl
+    try:
+        isl.BasicSet.card
+    except AttributeError:
+        expect_fallback = True
+    else:
+        expect_fallback = False
+
+    if expect_fallback:
+        if ell < group_size_0:
+            assert f32nonconsec == 3*n*m*ell*n_workgroups
+        else:
+            assert f32nonconsec == 3*n*m*n_workgroups*group_size_0
+    else:
+        assert f32nonconsec == 3*n*m*ell
 
     f64uniform = mem_map[lp.MemAccess('global', np.float64,
-                                stride=0, direction='store', variable='e')
+                                stride=0, direction='store', variable='e',
+                                count_granularity=CG.SUBGROUP)
                          ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.float32,
-                                  stride=Variable('m'), direction='store',
-                                  variable='c')
+                                stride=Variable('m'), direction='store',
+                                variable='c',
+                                count_granularity=CG.WORKITEM)
                            ].eval_with_dict(params)
-    assert f64uniform == n*m*ell/bsize
-    assert f32nonconsec == n*m*ell
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f64uniform == m*n*n_workgroups*subgroups_per_group
+
+    if expect_fallback:
+        if ell < group_size_0:
+            assert f32nonconsec == n*m*ell*n_workgroups
+        else:
+            assert f32nonconsec == n*m*n_workgroups*group_size_0
+    else:
+        assert f32nonconsec == n*m*ell
 
 
 def test_mem_access_counter_nonconsec():
@@ -502,41 +648,81 @@ def test_mem_access_counter_nonconsec():
     knl = lp.split_iname(knl, "i", 16)
     knl = lp.tag_inames(knl, {"i_inner": "l.0", "i_outer": "g.0"})
 
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)  # noqa
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=32)  # noqa
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
     f64nonconsec = mem_map[lp.MemAccess('global', np.float64,
-                                  stride=Variable('m'), direction='load',
-                                  variable='g')
+                                stride=Variable('m'), direction='load',
+                                variable='g',
+                                count_granularity=CG.WORKITEM)
                            ].eval_with_dict(params)
     f64nonconsec += mem_map[lp.MemAccess('global', np.float64,
-                                   stride=Variable('m'), direction='load',
-                                   variable='h')
+                                stride=Variable('m'), direction='load',
+                                variable='h',
+                                count_granularity=CG.WORKITEM)
                             ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.dtype(np.float32),
-                                  stride=Variable('m')*Variable('ell'),
-                                  direction='load', variable='a')
+                                stride=Variable('m')*Variable('ell'),
+                                direction='load', variable='a',
+                                count_granularity=CG.WORKITEM)
                            ].eval_with_dict(params)
     f32nonconsec += mem_map[lp.MemAccess('global', np.dtype(np.float32),
-                                   stride=Variable('m')*Variable('ell'),
-                                   direction='load', variable='b')
+                                stride=Variable('m')*Variable('ell'),
+                                direction='load', variable='b',
+                                count_granularity=CG.WORKITEM)
                             ].eval_with_dict(params)
     assert f64nonconsec == 2*n*m
     assert f32nonconsec == 3*n*m*ell
 
     f64nonconsec = mem_map[lp.MemAccess('global', np.float64,
-                                  stride=Variable('m'), direction='store',
-                                  variable='e')
+                                stride=Variable('m'), direction='store',
+                                variable='e',
+                                count_granularity=CG.WORKITEM)
                            ].eval_with_dict(params)
     f32nonconsec = mem_map[lp.MemAccess('global', np.float32,
-                                  stride=Variable('m')*Variable('ell'),
-                                  direction='store', variable='c')
+                                stride=Variable('m')*Variable('ell'),
+                                direction='store', variable='c',
+                                count_granularity=CG.WORKITEM)
                            ].eval_with_dict(params)
     assert f64nonconsec == n*m
     assert f32nonconsec == n*m*ell
 
+    mem_map64 = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                      subgroup_size=64)
+    f64nonconsec = mem_map64[lp.MemAccess(
+                    'global',
+                    np.float64, stride=Variable('m'),
+                    direction='load', variable='g',
+                    count_granularity=CG.WORKITEM)
+                    ].eval_with_dict(params)
+    f64nonconsec += mem_map64[lp.MemAccess(
+                    'global',
+                    np.float64, stride=Variable('m'),
+                    direction='load', variable='h',
+                    count_granularity=CG.WORKITEM)
+                    ].eval_with_dict(params)
+    f32nonconsec = mem_map64[lp.MemAccess(
+                    'global',
+                    np.dtype(np.float32),
+                    stride=Variable('m')*Variable('ell'),
+                    direction='load',
+                    variable='a',
+                    count_granularity=CG.WORKITEM)
+                    ].eval_with_dict(params)
+    f32nonconsec += mem_map64[lp.MemAccess(
+                    'global',
+                    np.dtype(np.float32),
+                    stride=Variable('m')*Variable('ell'),
+                    direction='load',
+                    variable='b',
+                    count_granularity=CG.WORKITEM)
+                    ].eval_with_dict(params)
+    assert f64nonconsec == 2*n*m
+    assert f32nonconsec == 3*n*m*ell
+
 
 def test_mem_access_counter_consec():
 
@@ -553,37 +739,69 @@ def test_mem_access_counter_consec():
                 a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     knl = lp.tag_inames(knl, {"k": "l.0", "i": "g.0", "j": "g.1"})
 
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size='guess')
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
 
     f64consec = mem_map[lp.MemAccess('global', np.float64,
-                        stride=1, direction='load', variable='g')
+                        stride=1, direction='load', variable='g',
+                        count_granularity=CG.WORKITEM)
                         ].eval_with_dict(params)
     f64consec += mem_map[lp.MemAccess('global', np.float64,
-                        stride=1, direction='load', variable='h')
+                        stride=1, direction='load', variable='h',
+                        count_granularity=CG.WORKITEM)
                          ].eval_with_dict(params)
     f32consec = mem_map[lp.MemAccess('global', np.float32,
-                        stride=1, direction='load', variable='a')
+                        stride=1, direction='load', variable='a',
+                        count_granularity=CG.WORKITEM)
                         ].eval_with_dict(params)
     f32consec += mem_map[lp.MemAccess('global', np.dtype(np.float32),
-                        stride=1, direction='load', variable='b')
+                        stride=1, direction='load', variable='b',
+                        count_granularity=CG.WORKITEM)
                          ].eval_with_dict(params)
     assert f64consec == 2*n*m*ell
     assert f32consec == 3*n*m*ell
 
     f64consec = mem_map[lp.MemAccess('global', np.float64,
-                        stride=1, direction='store', variable='e')
+                        stride=1, direction='store', variable='e',
+                        count_granularity=CG.WORKITEM)
                         ].eval_with_dict(params)
     f32consec = mem_map[lp.MemAccess('global', np.float32,
-                        stride=1, direction='store', variable='c')
+                        stride=1, direction='store', variable='c',
+                        count_granularity=CG.WORKITEM)
                         ].eval_with_dict(params)
     assert f64consec == n*m*ell
     assert f32consec == n*m*ell
 
 
+def test_count_granularity_val_checks():
+
+    try:
+        lp.MemAccess(count_granularity=CG.WORKITEM)
+        lp.MemAccess(count_granularity=CG.SUBGROUP)
+        lp.MemAccess(count_granularity=CG.WORKGROUP)
+        lp.MemAccess(count_granularity=None)
+        assert True
+        lp.MemAccess(count_granularity='bushel')
+        assert False
+    except ValueError:
+        assert True
+
+    try:
+        lp.Op(count_granularity=CG.WORKITEM)
+        lp.Op(count_granularity=CG.SUBGROUP)
+        lp.Op(count_granularity=CG.WORKGROUP)
+        lp.Op(count_granularity=None)
+        assert True
+        lp.Op(count_granularity='bushel')
+        assert False
+    except ValueError:
+        assert True
+
+
 def test_barrier_counter_nobarriers():
 
     knl = lp.make_kernel(
@@ -662,42 +880,48 @@ def test_all_counters_parallel_matmul():
 
     op_map = lp.get_op_map(knl, count_redundant_work=True)
     f32mul = op_map[
-                        lp.Op(np.float32, 'mul')
+                        lp.Op(np.float32, 'mul', CG.WORKITEM)
                         ].eval_with_dict(params)
     f32add = op_map[
-                        lp.Op(np.float32, 'add')
+                        lp.Op(np.float32, 'add', CG.WORKITEM)
                         ].eval_with_dict(params)
     i32ops = op_map[
-                        lp.Op(np.int32, 'add')
+                        lp.Op(np.int32, 'add', CG.WORKITEM)
                         ].eval_with_dict(params)
     i32ops += op_map[
-                        lp.Op(np.dtype(np.int32), 'mul')
+                        lp.Op(np.dtype(np.int32), 'mul', CG.WORKITEM)
                         ].eval_with_dict(params)
 
     assert f32mul+f32add == n*m*ell*2
 
-    op_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+    mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                           subgroup_size=32)
 
-    f32s1lb = op_map[lp.MemAccess('global', np.float32,
-                     stride=1, direction='load', variable='b')
-                     ].eval_with_dict(params)
-    f32s1la = op_map[lp.MemAccess('global', np.float32,
-                     stride=1, direction='load', variable='a')
-                     ].eval_with_dict(params)
+    f32s1lb = mem_access_map[lp.MemAccess('global', np.float32,
+                             stride=1, direction='load', variable='b',
+                             count_granularity=CG.WORKITEM)
+                             ].eval_with_dict(params)
+    f32s1la = mem_access_map[lp.MemAccess('global', np.float32,
+                             stride=1, direction='load', variable='a',
+                             count_granularity=CG.WORKITEM)
+                             ].eval_with_dict(params)
 
     assert f32s1lb == n*m*ell/bsize
     assert f32s1la == n*m*ell/bsize
 
-    f32coal = op_map[lp.MemAccess('global', np.float32,
-                     stride=1, direction='store', variable='c')
-                     ].eval_with_dict(params)
+    f32coal = mem_access_map[lp.MemAccess('global', np.float32,
+                             stride=1, direction='store', variable='c',
+                             count_granularity=CG.WORKITEM)
+                             ].eval_with_dict(params)
 
     assert f32coal == n*ell
 
     local_mem_map = lp.get_mem_access_map(knl,
-                        count_redundant_work=True).filter_by(mtype=['local'])
+                        count_redundant_work=True,
+                        subgroup_size=32).filter_by(mtype=['local'])
     local_mem_l = local_mem_map[lp.MemAccess('local', np.dtype(np.float32),
-                                             direction='load')
+                                             direction='load',
+                                             count_granularity=CG.WORKITEM)
                                 ].eval_with_dict(params)
     assert local_mem_l == n*m*ell*2
 
@@ -747,28 +971,46 @@ def test_summations_and_filters():
             name="basic", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl,
-                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
+                    dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
+
+    subgroup_size = 32
+
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
 
-    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, subgroup_size)
 
-    loads_a = mem_map.filter_by(direction=['load'], variable=['a']
+    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                    subgroup_size=subgroup_size)
+
+    loads_a = mem_map.filter_by(direction=['load'], variable=['a'],
+                                count_granularity=[CG.SUBGROUP]
                                 ).eval_and_sum(params)
-    assert loads_a == 2*n*m*ell
 
-    global_stores = mem_map.filter_by(mtype=['global'], direction=['store']
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert loads_a == (2*n*m*ell)*n_workgroups*subgroups_per_group
+
+    global_stores = mem_map.filter_by(mtype=['global'], direction=['store'],
+                                      count_granularity=[CG.SUBGROUP]
                                       ).eval_and_sum(params)
-    assert global_stores == n*m*ell + n*m
 
-    ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load']
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert global_stores == (n*m*ell + n*m)*n_workgroups*subgroups_per_group
+
+    ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load'],
+                                 count_granularity=[CG.SUBGROUP]
                                  ).to_bytes().eval_and_sum(params)
-    st_bytes = mem_map.filter_by(mtype=['global'], direction=['store']
+    st_bytes = mem_map.filter_by(mtype=['global'], direction=['store'],
+                                 count_granularity=[CG.SUBGROUP]
                                  ).to_bytes().eval_and_sum(params)
-    assert ld_bytes == 4*n*m*ell*3 + 8*n*m*2
-    assert st_bytes == 4*n*m*ell + 8*n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert ld_bytes == (4*n*m*ell*3 + 8*n*m*2)*n_workgroups*subgroups_per_group
+    assert st_bytes == (4*n*m*ell + 8*n*m)*n_workgroups*subgroups_per_group
 
     # ignore stride and variable names in this map
     reduced_map = mem_map.group_by('mtype', 'dtype', 'direction')
@@ -776,8 +1018,10 @@ def test_summations_and_filters():
                           ].eval_with_dict(params)
     f64lall = reduced_map[lp.MemAccess('global', np.float64, direction='load')
                           ].eval_with_dict(params)
-    assert f32lall == 3*n*m*ell
-    assert f64lall == 2*n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert f32lall == (3*n*m*ell)*n_workgroups*subgroups_per_group
+    assert f64lall == (2*n*m)*n_workgroups*subgroups_per_group
 
     op_map = lp.get_op_map(knl, count_redundant_work=True)
     #for k, v in op_map.items():
@@ -810,7 +1054,9 @@ def test_summations_and_filters():
         return key.stride < 1 and key.dtype == to_loopy_type(np.float64) and \
                key.direction == 'load'
     s1f64l = mem_map.filter_by_func(func_filter).eval_and_sum(params)
-    assert s1f64l == 2*n*m
+
+    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
+    assert s1f64l == (2*n*m)*n_workgroups*subgroups_per_group
 
 
 def test_strided_footprint():