diff --git a/loopy/statistics.py b/loopy/statistics.py
index cee28b24f8bdd44392f41f437c6042f3aa08ce2c..2df3093d1b58babd35c61efdbcee20bba243c643 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -715,7 +715,8 @@ class ExpressionOpCounter(CounterBase):
         return ToCountMap(
                     {Op(dtype=self.type_inf(expr),
                         name='func:'+str(expr.function),
-                        count_granularity=CountGranularity.WORKITEM): 1}
+                        #count_granularity=CountGranularity.WORKITEM): 1}
+                        count_granularity=CountGranularity.SUBGROUP): 1}
                     ) + self.rec(expr.parameters)
 
     def map_subscript(self, expr):
@@ -726,7 +727,8 @@ class ExpressionOpCounter(CounterBase):
         return ToCountMap(
                     {Op(dtype=self.type_inf(expr),
                         name='add',
-                        count_granularity=CountGranularity.WORKITEM):
+                        #count_granularity=CountGranularity.WORKITEM):
+                        count_granularity=CountGranularity.SUBGROUP):
                      len(expr.children)-1}
                     ) + sum(self.rec(child) for child in expr.children)
 
@@ -735,18 +737,21 @@ class ExpressionOpCounter(CounterBase):
         assert expr.children
         return sum(ToCountMap({Op(dtype=self.type_inf(expr),
                                   name='mul',
-                                  count_granularity=CountGranularity.WORKITEM): 1})
+                                  #count_granularity=CountGranularity.WORKITEM): 1})
+                                  count_granularity=CountGranularity.SUBGROUP): 1})
                    + self.rec(child)
                    for child in expr.children
                    if not is_zero(child + 1)) + \
                    ToCountMap({Op(dtype=self.type_inf(expr),
                                   name='mul',
-                                  count_granularity=CountGranularity.WORKITEM): -1})
+                                  #count_granularity=CountGranularity.WORKITEM): -1})
+                                  count_granularity=CountGranularity.SUBGROUP): -1})
 
     def map_quotient(self, expr, *args):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='div',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                                 + self.rec(expr.numerator) \
                                 + self.rec(expr.denominator)
 
@@ -756,14 +761,16 @@ class ExpressionOpCounter(CounterBase):
     def map_power(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='pow',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                                 + self.rec(expr.base) \
                                 + self.rec(expr.exponent)
 
     def map_left_shift(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='shift',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                                 + self.rec(expr.shiftee) \
                                 + self.rec(expr.shift)
 
@@ -772,13 +779,15 @@ class ExpressionOpCounter(CounterBase):
     def map_bitwise_not(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='bw',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                                 + self.rec(expr.child)
 
     def map_bitwise_or(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='bw',
-                              count_granularity=CountGranularity.WORKITEM):
+                              #count_granularity=CountGranularity.WORKITEM):
+                              count_granularity=CountGranularity.SUBGROUP):
                            len(expr.children)-1}) \
                                 + sum(self.rec(child) for child in expr.children)
 
@@ -802,7 +811,8 @@ class ExpressionOpCounter(CounterBase):
     def map_min(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='maxmin',
-                              count_granularity=CountGranularity.WORKITEM):
+                              #count_granularity=CountGranularity.WORKITEM):
+                              count_granularity=CountGranularity.SUBGROUP):
                            len(expr.children)-1}) \
                + sum(self.rec(child) for child in expr.children)
 
@@ -1329,14 +1339,109 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False,
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
 
+    if not isinstance(subgroup_size, int):
+        # try to find subgroup_size
+        subgroup_size_guess = _find_subgroup_size_for_knl(knl)
+
+        if subgroup_size is None:
+            if subgroup_size_guess is None:
+                # 'guess' was not passed and either no target device found
+                # or get_simd_group_size returned None
+                raise ValueError("No sub-group size passed, no target device found. "
+                                 "Either (1) pass integer value for subgroup_size, "
+                                 "(2) ensure that kernel.target is PyOpenClTarget "
+                                 "and kernel.target.device is set, or (3) pass "
+                                 "subgroup_size='guess' and hope for the best.")
+            else:
+                subgroup_size = subgroup_size_guess
+
+        elif subgroup_size == 'guess':
+            if subgroup_size_guess is None:
+                # unable to get subgroup_size from device, so guess
+                subgroup_size = 32
+                warn_with_kernel(knl, "get_op_map_guessing_subgroup_size",
+                                 "get_op_map: 'guess' sub-group size "
+                                 "passed, no target device found, wildly guessing "
+                                 "that sub-group size is %d." % (subgroup_size))
+            else:
+                subgroup_size = subgroup_size_guess
+        else:
+            raise ValueError("Invalid value for subgroup_size: %s. subgroup_size "
+                             "must be integer, 'guess', or, if you're feeling "
+                             "lucky, None." % (subgroup_size))
+
+    # ------------------------------
+    #class CacheHolder(object):
+    #    pass
+
+    #cache_holder = CacheHolder()
+    #from pytools import memoize_in
+
+    #@memoize_in(cache_holder, "insn_count")
+    def get_insn_count(knl, insn, count_granularity=CountGranularity.WORKITEM):
+
+        if count_granularity is None:
+            warn_with_kernel(knl, "get_insn_count_assumes_granularity",
+                             "get_insn_count: No count granularity passed for "
+                             "Op, assuming %s granularity."
+                             % (CountGranularity.WORKITEM))
+            count_granularity == CountGranularity.WORKITEM
+
+        if count_granularity == CountGranularity.WORKITEM:
+            return count_insn_runs(
+                knl, insn, count_redundant_work=count_redundant_work,
+                disregard_local_axes=False)
+
+        ct_disregard_local = count_insn_runs(
+                knl, insn, disregard_local_axes=True,
+                count_redundant_work=count_redundant_work)
+
+        if count_granularity == CountGranularity.WORKGROUP:
+            return ct_disregard_local
+        elif count_granularity == CountGranularity.SUBGROUP:
+            # get the group size
+            from loopy.symbolic import aff_to_expr
+            _, local_size = knl.get_grid_size_upper_bounds()
+            workgroup_size = 1
+            if local_size:
+                for size in local_size:
+                    s = aff_to_expr(size)
+                    if not isinstance(s, int):
+                        raise LoopyError("Cannot count insn with %s granularity, "
+                                         "work-group size is not integer: %s"
+                                         % (CountGranularity.SUBGROUP, local_size))
+                    workgroup_size *= s
+
+            warn_with_kernel(knl, "insn_count_subgroups_upper_bound",
+                    "get_insn_count: when counting instruction %s with "
+                    "count_granularity=%s, using upper bound for work-group size "
+                    "(%d work-items) to compute sub-groups per work-group. When "
+                    "multiple device programs present, actual sub-group count may be"
+                    "lower." % (insn, CountGranularity.SUBGROUP, workgroup_size))
+
+            from pytools import div_ceil
+            return ct_disregard_local*div_ceil(workgroup_size, subgroup_size)
+        else:
+            # this should not happen since this is enforced in Op
+            raise ValueError("get_insn_count: count_granularity '%s' is"
+                    "not allowed. count_granularity options: %s"
+                    % (count_granularity, CountGranularity.ALL+[None]))
+    # ------------------------------
+
     op_map = ToCountMap()
     op_counter = ExpressionOpCounter(knl)
     for insn in knl.instructions:
         if isinstance(insn, (CallInstruction, CInstruction, Assignment)):
             ops = op_counter(insn.assignee) + op_counter(insn.expression)
-            op_map = op_map + ops*count_insn_runs(
-                    knl, insn,
-                    count_redundant_work=count_redundant_work)
+            #op_map = op_map + ops*count_insn_runs(
+            #        knl, insn,
+            #        count_redundant_work=count_redundant_work)
+            for key, val in six.iteritems(ops):
+                op_map = (
+                        op_map
+                        + ToCountMap({key: val})
+                        * get_insn_count(knl, insn, key.count_granularity))
+
         elif isinstance(insn, (NoOpInstruction, BarrierInstruction)):
             pass
         else:
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 79c5ec7da0971b534588be3bfcd58a9f5fc8405a..b5b55347c99df3abfd5301bc037df611a67126f4 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -39,6 +39,9 @@ from pymbolic.primitives import Variable
 from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2  # noqa
 
 
+SGS = 32  # Subgroup size
+
+
 def test_op_counter_basic():
 
     knl = lp.make_kernel(
@@ -54,21 +57,26 @@ def test_op_counter_basic():
     knl = lp.add_and_infer_dtypes(knl,
                                   dict(a=np.float32, b=np.float32,
                                        g=np.float64, h=np.float64))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
-    f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params)
-    f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.WORKITEM)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params)
+    f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP)].eval_with_dict(params)
+    f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32add == f32mul == f32div == n*m*ell
-    assert f64mul == n*m
-    assert i32add == n*m*2
+    # (count-per-sub-group)*n_subgroups
+    assert f32add == f32mul == f32div == n*m*ell*n_subgroups
+    assert f64mul == n*m*n_subgroups
+    assert i32add == n*m*2*n_subgroups
 
 
 def test_op_counter_reduction():
@@ -81,15 +89,20 @@ def test_op_counter_reduction():
             name="matmul_serial", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.WORKITEM)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32add == f32mul == n*m*ell
+    # (count-per-sub-group)*n_subgroups
+    assert f32add == f32mul == n*m*ell*n_subgroups
 
     op_map_dtype = op_map.group_by('dtype')
     f32 = op_map_dtype[lp.Op(dtype=np.float32)].eval_with_dict(params)
@@ -111,21 +124,26 @@ def test_op_counter_logic():
             name="logic", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
-    f64add = op_map[lp.Op(np.float64, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.WORKITEM)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params)
+    f64add = op_map[lp.Op(np.float64, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32mul == n*m
-    assert f64div == 2*n*m  # TODO why?
-    assert f64add == n*m
-    assert i32add == n*m
+    # (count-per-sub-group)*n_subgroups
+    assert f32mul == n*m*n_subgroups
+    assert f64div == 2*n*m*n_subgroups  # TODO why?
+    assert f64add == n*m*n_subgroups
+    assert i32add == n*m*n_subgroups
 
 
 def test_op_counter_specialops():
@@ -143,27 +161,32 @@ def test_op_counter_specialops():
     knl = lp.add_and_infer_dtypes(knl,
                                   dict(a=np.float32, b=np.float32,
                                        g=np.float64, h=np.float64))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
-    f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params)
-    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f64pow = op_map[lp.Op(np.float64, 'pow', CG.WORKITEM)].eval_with_dict(params)
-    f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.WORKITEM)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params)
+    f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP)].eval_with_dict(params)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f64pow = op_map[lp.Op(np.float64, 'pow', CG.SUBGROUP)].eval_with_dict(params)
+    f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.WORKITEM)
+    f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin', CG.WORKITEM)
+    f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32div == 2*n*m*ell
-    assert f32mul == f32add == n*m*ell
-    assert f64add == 3*n*m
-    assert f64pow == i32add == f64rsq == f64sin == n*m
+    # (count-per-sub-group)*n_subgroups
+    assert f32div == 2*n*m*ell*n_subgroups
+    assert f32mul == f32add == n*m*ell*n_subgroups
+    assert f64add == 3*n*m*n_subgroups
+    assert f64pow == i32add == f64rsq == f64sin == n*m*n_subgroups
 
 
 def test_op_counter_bitwise():
@@ -183,26 +206,31 @@ def test_op_counter_bitwise():
                 a=np.int32, b=np.int32,
                 g=np.int64, h=np.int64))
 
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    i32add = op_map[lp.Op(np.int32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    i32bw = op_map[lp.Op(np.int32, 'bw', CG.WORKITEM)].eval_with_dict(params)
-    i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.int32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    i32bw = op_map[lp.Op(np.int32, 'bw', CG.SUBGROUP)].eval_with_dict(params)
+    i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.SUBGROUP)
                    ].eval_with_dict(params)
-    i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.WORKITEM)
+    i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.WORKITEM)
+    i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.WORKITEM)
+    i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.SUBGROUP)
                       ].eval_with_dict(params)
-    assert i32add == n*m+n*m*ell
-    assert i32bw == 2*n*m*ell
-    assert i64bw == 2*n*m
-    assert i64add == i64mul == n*m
-    assert i64shift == 2*n*m
+    # (count-per-sub-group)*n_subgroups
+    assert i32add == n*m+n*m*ell*n_subgroups
+    assert i32bw == 2*n*m*ell*n_subgroups
+    assert i64bw == 2*n*m*n_subgroups
+    assert i64add == i64mul == n*m*n_subgroups
+    assert i64shift == 2*n*m*n_subgroups
 
 
 def test_op_counter_triangular_domain():
@@ -228,15 +256,21 @@ def test_op_counter_triangular_domain():
 
     op_map = lp.get_op_map(
                     knl,
+                    subgroup_size=SGS,
                     count_redundant_work=True
-                    )[lp.Op(np.float64, 'mul', CG.WORKITEM)]
+                    )[lp.Op(np.float64, 'mul', CG.SUBGROUP)]
     value_dict = dict(m=13, n=200)
     flops = op_map.eval_with_dict(value_dict)
 
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
+
     if expect_fallback:
-        assert flops == 144
+        assert flops == 144*n_subgroups
     else:
-        assert flops == 78
+        assert flops == 78*n_subgroups
 
 
 def test_mem_access_counter_basic():
@@ -254,10 +288,8 @@ def test_mem_access_counter_basic():
     knl = lp.add_and_infer_dtypes(knl,
                     dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
 
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
 
     n = 512
     m = 256
@@ -266,7 +298,8 @@ def test_mem_access_counter_basic():
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     f32l = mem_map[lp.MemAccess('global', np.float32,
                         lid_strides={}, gid_strides={},
@@ -289,9 +322,9 @@ def test_mem_access_counter_basic():
                         count_granularity=CG.SUBGROUP)
                     ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32l == (3*n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64l == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32l == (3*n*m*ell)*n_subgroups
+    assert f64l == (2*n*m)*n_subgroups
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                         lid_strides={}, gid_strides={},
@@ -304,9 +337,9 @@ def test_mem_access_counter_basic():
                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32s == (n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64s == (n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32s == (n*m*ell)*n_subgroups
+    assert f64s == (n*m)*n_subgroups
 
 
 def test_mem_access_counter_reduction():
@@ -320,10 +353,8 @@ def test_mem_access_counter_reduction():
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
 
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -331,7 +362,8 @@ def test_mem_access_counter_reduction():
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     f32l = mem_map[lp.MemAccess('global', np.float32,
                         lid_strides={}, gid_strides={},
@@ -344,8 +376,8 @@ def test_mem_access_counter_reduction():
                         count_granularity=CG.SUBGROUP)
                     ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32l == (2*n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32l == (2*n*m*ell)*n_subgroups
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                         lid_strides={}, gid_strides={},
@@ -353,8 +385,8 @@ def test_mem_access_counter_reduction():
                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32s == (n*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32s == (n*ell)*n_subgroups
 
     ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load']
                                  ).to_bytes().eval_and_sum(params)
@@ -379,10 +411,8 @@ def test_mem_access_counter_logic():
 
     knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
 
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -390,7 +420,8 @@ def test_mem_access_counter_logic():
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     reduced_map = mem_map.group_by('mtype', 'dtype', 'direction')
 
@@ -404,10 +435,10 @@ def test_mem_access_counter_logic():
                                        direction='store')
                           ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32_g_l == (2*n*m)*n_workgroups*subgroups_per_group
-    assert f64_g_l == (n*m)*n_workgroups*subgroups_per_group
-    assert f64_g_s == (n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32_g_l == (2*n*m)*n_subgroups
+    assert f64_g_l == (n*m)*n_subgroups
+    assert f64_g_s == (n*m)*n_subgroups
 
 
 def test_mem_access_counter_specialops():
@@ -425,10 +456,8 @@ def test_mem_access_counter_specialops():
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32,
                                             g=np.float64, h=np.float64))
 
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -436,7 +465,8 @@ def test_mem_access_counter_specialops():
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     f32 = mem_map[lp.MemAccess('global', np.float32,
                         lid_strides={}, gid_strides={},
@@ -459,9 +489,9 @@ def test_mem_access_counter_specialops():
                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32 == (2*n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64 == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32 == (2*n*m*ell)*n_subgroups
+    assert f64 == (2*n*m)*n_subgroups
 
     f32 = mem_map[lp.MemAccess('global', np.float32,
                         lid_strides={}, gid_strides={},
@@ -474,16 +504,16 @@ def test_mem_access_counter_specialops():
                         count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32 == (n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64 == (n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32 == (n*m*ell)*n_subgroups
+    assert f64 == (n*m)*n_subgroups
 
     filtered_map = mem_map.filter_by(direction=['load'], variable=['a', 'g'],
                          count_granularity=CG.SUBGROUP)
     tot = filtered_map.eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert tot == (n*m*ell + n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert tot == (n*m*ell + n*m)*n_subgroups
 
 
 def test_mem_access_counter_bitwise():
@@ -503,10 +533,8 @@ def test_mem_access_counter_bitwise():
                 a=np.int32, b=np.int32,
                 g=np.int32, h=np.int32))
 
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -514,7 +542,8 @@ def test_mem_access_counter_bitwise():
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     i32 = mem_map[lp.MemAccess('global', np.int32,
                         lid_strides={}, gid_strides={},
@@ -537,8 +566,8 @@ def test_mem_access_counter_bitwise():
                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert i32 == (4*n*m+2*n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert i32 == (4*n*m+2*n*m*ell)*n_subgroups
 
     i32 = mem_map[lp.MemAccess('global', np.int32,
                         lid_strides={}, gid_strides={},
@@ -551,8 +580,8 @@ def test_mem_access_counter_bitwise():
                         count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert i32 == (n*m+n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert i32 == (n*m+n*m*ell)*n_subgroups
 
 
 def test_mem_access_counter_mixed():
@@ -571,7 +600,6 @@ def test_mem_access_counter_mixed():
                 x=np.float32))
 
     group_size_0 = 65
-    subgroup_size = 32
 
     knl = lp.split_iname(knl, "j", group_size_0)
     knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"})
@@ -583,10 +611,11 @@ def test_mem_access_counter_mixed():
 
     n_workgroups = div_ceil(ell, group_size_0)
     group_size = group_size_0
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     f64uniform = mem_map[lp.MemAccess('global', np.float64,
                                 lid_strides={}, gid_strides={},
                                 direction='load', variable='g',
@@ -617,9 +646,9 @@ def test_mem_access_counter_mixed():
                                 count_granularity=CG.WORKITEM)
                             ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f64uniform == (2*n*m)*n_workgroups*subgroups_per_group
-    assert f32uniform == (m*n)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f64uniform == (2*n*m)*n_subgroups
+    assert f32uniform == (m*n)*n_subgroups
 
     expect_fallback = False
     import islpy as isl
@@ -651,8 +680,8 @@ def test_mem_access_counter_mixed():
                                 count_granularity=CG.WORKITEM)
                            ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f64uniform == m*n*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f64uniform == m*n*n_subgroups
 
     if expect_fallback:
         if ell < group_size_0:
@@ -681,7 +710,7 @@ def test_mem_access_counter_nonconsec():
     knl = lp.tag_inames(knl, {"i_inner": "l.0", "i_outer": "g.0"})
 
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=32)  # noqa
+                                    subgroup_size=SGS)  # noqa
     n = 512
     m = 256
     ell = 128
@@ -939,30 +968,35 @@ def test_all_counters_parallel_matmul():
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
+    group_size = bsize*bsize
+    n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     sync_map = lp.get_synchronization_map(knl)
     assert len(sync_map) == 2
     assert sync_map["kernel_launch"].eval_with_dict(params) == 1
     assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize
 
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
     f32mul = op_map[
-                        lp.Op(np.float32, 'mul', CG.WORKITEM)
+                        lp.Op(np.float32, 'mul', CG.SUBGROUP)
                         ].eval_with_dict(params)
     f32add = op_map[
-                        lp.Op(np.float32, 'add', CG.WORKITEM)
+                        lp.Op(np.float32, 'add', CG.SUBGROUP)
                         ].eval_with_dict(params)
     i32ops = op_map[
-                        lp.Op(np.int32, 'add', CG.WORKITEM)
+                        lp.Op(np.int32, 'add', CG.SUBGROUP)
                         ].eval_with_dict(params)
     i32ops += op_map[
-                        lp.Op(np.dtype(np.int32), 'mul', CG.WORKITEM)
+                        lp.Op(np.dtype(np.int32), 'mul', CG.SUBGROUP)
                         ].eval_with_dict(params)
 
-    assert f32mul+f32add == n*m*ell*2
+    # (count-per-sub-group)*n_subgroups
+    assert f32mul+f32add == m*2*n_subgroups
 
     mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                           subgroup_size=32)
+                                           subgroup_size=SGS)
 
     f32s1lb = mem_access_map[lp.MemAccess('global', np.float32,
                              lid_strides={0: 1, 1: Variable('ell')},
@@ -991,7 +1025,7 @@ def test_all_counters_parallel_matmul():
 
     local_mem_map = lp.get_mem_access_map(knl,
                         count_redundant_work=True,
-                        subgroup_size=32).filter_by(mtype=['local'])
+                        subgroup_size=SGS).filter_by(mtype=['local'])
 
     local_mem_l = local_mem_map.filter_by(direction=['load']
                                           ).eval_and_sum(params)
@@ -1067,8 +1101,6 @@ def test_summations_and_filters():
     knl = lp.add_and_infer_dtypes(knl,
                     dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
 
-    subgroup_size = 32
-
     n = 512
     m = 256
     ell = 128
@@ -1076,24 +1108,25 @@ def test_summations_and_filters():
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
 
     loads_a = mem_map.filter_by(direction=['load'], variable=['a'],
                                 count_granularity=[CG.SUBGROUP]
                                 ).eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert loads_a == (2*n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert loads_a == (2*n*m*ell)*n_subgroups
 
     global_stores = mem_map.filter_by(mtype=['global'], direction=['store'],
                                       count_granularity=[CG.SUBGROUP]
                                       ).eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert global_stores == (n*m*ell + n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert global_stores == (n*m*ell + n*m)*n_subgroups
 
     ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load'],
                                  count_granularity=[CG.SUBGROUP]
@@ -1102,9 +1135,9 @@ def test_summations_and_filters():
                                  count_granularity=[CG.SUBGROUP]
                                  ).to_bytes().eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert ld_bytes == (4*n*m*ell*3 + 8*n*m*2)*n_workgroups*subgroups_per_group
-    assert st_bytes == (4*n*m*ell + 8*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert ld_bytes == (4*n*m*ell*3 + 8*n*m*2)*n_subgroups
+    assert st_bytes == (4*n*m*ell + 8*n*m)*n_subgroups
 
     # ignore stride and variable names in this map
     reduced_map = mem_map.group_by('mtype', 'dtype', 'direction')
@@ -1113,11 +1146,11 @@ def test_summations_and_filters():
     f64lall = reduced_map[lp.MemAccess('global', np.float64, direction='load')
                           ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32lall == (3*n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64lall == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32lall == (3*n*m*ell)*n_subgroups
+    assert f64lall == (2*n*m)*n_subgroups
 
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
     #for k, v in op_map.items():
     #    print(type(k), "\n", k.name, k.dtype, type(k.dtype), " :\n", v)
 
@@ -1149,8 +1182,8 @@ def test_summations_and_filters():
                key.direction == 'load'
     f64l = mem_map.filter_by_func(func_filter).eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f64l == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f64l == (2*n*m)*n_subgroups
 
 
 def test_strided_footprint():