diff --git a/loopy/statistics.py b/loopy/statistics.py
index b467e3334249da7ce1d36caa55c363d4a8941bb8..9ce2bb081eca67cc6f41864c7ce5965e018ce853 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -581,6 +581,11 @@ class MemAccess(Record):
        A :class:`str` that specifies the variable name of the data
        accessed.
 
+    .. attribute:: variable_tag
+
+       A :class:`str` that specifies the tag of the accessed variable if it is
+       a :class:`loopy.symbolic.TaggedVariable`, or *None* if it is untagged.
+
     .. attribute:: count_granularity
 
        A :class:`str` that specifies whether this operation should be counted
@@ -597,7 +602,8 @@ class MemAccess(Record):
     """
 
     def __init__(self, mtype=None, dtype=None, lid_strides=None, gid_strides=None,
-                 direction=None, variable=None, count_granularity=None):
+                 direction=None, variable=None, variable_tag=None,
+                 count_granularity=None):
 
         if count_granularity not in CountGranularity.ALL+[None]:
             raise ValueError("Op.__init__: count_granularity '%s' is "
@@ -607,12 +613,14 @@ class MemAccess(Record):
         if dtype is None:
             Record.__init__(self, mtype=mtype, dtype=dtype, lid_strides=lid_strides,
                             gid_strides=gid_strides, direction=direction,
-                            variable=variable, count_granularity=count_granularity)
+                            variable=variable, variable_tag=variable_tag,
+                            count_granularity=count_granularity)
         else:
             from loopy.types import to_loopy_type
             Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype),
                             lid_strides=lid_strides, gid_strides=gid_strides,
                             direction=direction, variable=variable,
+                            variable_tag=variable_tag,
                             count_granularity=count_granularity)
 
     def __hash__(self):
@@ -622,7 +630,7 @@ class MemAccess(Record):
 
     def __repr__(self):
         # Record.__repr__ overridden for consistent ordering and conciseness
-        return "MemAccess(%s, %s, %s, %s, %s, %s, %s)" % (
+        return "MemAccess(%s, %s, %s, %s, %s, %s, %s, %s)" % (
             self.mtype,
             self.dtype,
             None if self.lid_strides is None else dict(
@@ -631,6 +639,7 @@ class MemAccess(Record):
                 sorted(six.iteritems(self.gid_strides))),
             self.direction,
             self.variable,
+            self.variable_tag,
             self.count_granularity)
 
 # }}}
@@ -985,6 +994,10 @@ class GlobalMemAccessCounter(MemAccessCounter):
 
     def map_subscript(self, expr):
         name = expr.aggregate.name
+        try:
+            var_tag = expr.aggregate.tag
+        except AttributeError:
+            var_tag = None
 
         if name in self.knl.arg_dict:
             array = self.knl.arg_dict[name]
@@ -1013,6 +1026,7 @@ class GlobalMemAccessCounter(MemAccessCounter):
                             lid_strides=dict(sorted(six.iteritems(lid_strides))),
                             gid_strides=dict(sorted(six.iteritems(gid_strides))),
                             variable=name,
+                            variable_tag=var_tag,
                             count_granularity=count_granularity
                             ): 1}
                           ) + self.rec(expr.index_tuple)
@@ -1634,6 +1648,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
                             gid_strides=mem_access.gid_strides,
                             direction=mem_access.direction,
                             variable=mem_access.variable,
+                            variable_tag=mem_access.variable_tag,
                             count_granularity=mem_access.count_granularity),
                         ct)
                         for mem_access, ct in six.iteritems(access_map.count_map)),
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 41b44b5a7e9bbfe8f371e6a605ccfa8068a563b6..b29edf1ed05f7728b2cbe5b5ad8a74c26944ed8c 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -1060,6 +1060,65 @@ def test_all_counters_parallel_matmul():
     assert local_mem_s == m*2/bsize*n_subgroups
 
 
+def test_mem_access_tagged_variables():
+    """Check MemAccess counts for tagged variables (``name$tag`` syntax)."""
+    bsize = 16
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<ell}",
+            [
+                "c$mmresult[i, j] = sum(k, a$mmaload[i, k]*b$mmbload[k, j])"
+            ],
+            name="matmul", assumptions="n,m,ell >= 1")
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
+    knl = lp.split_iname(knl, "i", bsize, outer_tag="g.0", inner_tag="l.1")
+    knl = lp.split_iname(knl, "j", bsize, outer_tag="g.1", inner_tag="l.0")
+    knl = lp.split_iname(knl, "k", bsize)
+    # NOTE: no add_prefetch on a/b here; the lookups below expect global loads.
+
+    n = 512
+    m = 256
+    ell = 128
+    params = {'n': n, 'm': m, 'ell': ell}
+    group_size = bsize*bsize
+    n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
+
+    mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True,
+                                           subgroup_size=SGS)
+
+    f32s1lb = mem_access_map[lp.MemAccess('global', np.float32,
+                             lid_strides={0: 1},
+                             gid_strides={1: bsize},
+                             direction='load', variable='b',
+                             variable_tag='mmbload',
+                             count_granularity=CG.WORKITEM)
+                             ].eval_with_dict(params)
+    f32s1la = mem_access_map[lp.MemAccess('global', np.float32,
+                             lid_strides={1: Variable('m')},
+                             gid_strides={0: Variable('m')*bsize},
+                             direction='load',
+                             variable='a',
+                             variable_tag='mmaload',
+                             count_granularity=CG.SUBGROUP)
+                             ].eval_with_dict(params)
+
+    assert f32s1lb == n*m*ell
+
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32s1la == m*n_subgroups
+
+    f32coal = mem_access_map[lp.MemAccess('global', np.float32,
+                             lid_strides={0: 1, 1: Variable('ell')},
+                             gid_strides={0: Variable('ell')*bsize, 1: bsize},
+                             direction='store', variable='c',
+                             variable_tag='mmresult',
+                             count_granularity=CG.WORKITEM)
+                             ].eval_with_dict(params)
+
+    assert f32coal == n*ell
+
+
 def test_gather_access_footprint():
     knl = lp.make_kernel(
             "{[i,k,j]: 0<=i,j,k<n}",