diff --git a/loopy/statistics.py b/loopy/statistics.py index b467e3334249da7ce1d36caa55c363d4a8941bb8..9ce2bb081eca67cc6f41864c7ce5965e018ce853 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -581,6 +581,11 @@ class MemAccess(Record): A :class:`str` that specifies the variable name of the data accessed. + .. attribute:: variable_tag + + A :class:`str` that specifies the variable tag of a + :class:`pymbolic.primitives.TaggedVariable`. + .. attribute:: count_granularity A :class:`str` that specifies whether this operation should be counted @@ -597,7 +602,8 @@ class MemAccess(Record): """ def __init__(self, mtype=None, dtype=None, lid_strides=None, gid_strides=None, - direction=None, variable=None, count_granularity=None): + direction=None, variable=None, variable_tag=None, + count_granularity=None): if count_granularity not in CountGranularity.ALL+[None]: raise ValueError("Op.__init__: count_granularity '%s' is " @@ -607,12 +613,14 @@ class MemAccess(Record): if dtype is None: Record.__init__(self, mtype=mtype, dtype=dtype, lid_strides=lid_strides, gid_strides=gid_strides, direction=direction, - variable=variable, count_granularity=count_granularity) + variable=variable, variable_tag=variable_tag, + count_granularity=count_granularity) else: from loopy.types import to_loopy_type Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype), lid_strides=lid_strides, gid_strides=gid_strides, direction=direction, variable=variable, + variable_tag=variable_tag, count_granularity=count_granularity) def __hash__(self): @@ -622,7 +630,7 @@ class MemAccess(Record): def __repr__(self): # Record.__repr__ overridden for consistent ordering and conciseness - return "MemAccess(%s, %s, %s, %s, %s, %s, %s)" % ( + return "MemAccess(%s, %s, %s, %s, %s, %s, %s, %s)" % ( self.mtype, self.dtype, None if self.lid_strides is None else dict( @@ -631,6 +639,7 @@ class MemAccess(Record): sorted(six.iteritems(self.gid_strides))), self.direction, self.variable, + self.variable_tag, self.count_granularity) # }}} @@ -985,6 +994,10 @@ class GlobalMemAccessCounter(MemAccessCounter): def map_subscript(self, expr): name = expr.aggregate.name + try: + var_tag = expr.aggregate.tag + except AttributeError: + var_tag = None if name in self.knl.arg_dict: array = self.knl.arg_dict[name] @@ -1013,6 +1026,7 @@ class GlobalMemAccessCounter(MemAccessCounter): lid_strides=dict(sorted(six.iteritems(lid_strides))), gid_strides=dict(sorted(six.iteritems(gid_strides))), variable=name, + variable_tag=var_tag, count_granularity=count_granularity ): 1} ) + self.rec(expr.index_tuple) @@ -1634,6 +1648,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, gid_strides=mem_access.gid_strides, direction=mem_access.direction, variable=mem_access.variable, + variable_tag=mem_access.variable_tag, count_granularity=mem_access.count_granularity), ct) for mem_access, ct in six.iteritems(access_map.count_map)), diff --git a/test/test_statistics.py b/test/test_statistics.py index 41b44b5a7e9bbfe8f371e6a605ccfa8068a563b6..b29edf1ed05f7728b2cbe5b5ad8a74c26944ed8c 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1060,6 +1060,65 @@ def test_all_counters_parallel_matmul(): assert local_mem_s == m*2/bsize*n_subgroups +def test_mem_access_tagged_variables(): + bsize = 16 + knl = lp.make_kernel( + "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<ell}", + [ + "c$mmresult[i, j] = sum(k, a$mmaload[i, k]*b$mmbload[k, j])" + ], + name="matmul", assumptions="n,m,ell >= 1") + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) + knl = lp.split_iname(knl, "i", bsize, outer_tag="g.0", inner_tag="l.1") + knl = lp.split_iname(knl, "j", bsize, outer_tag="g.1", inner_tag="l.0") + knl = lp.split_iname(knl, "k", bsize) + # knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"], default_tag="l.auto") + # knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"], default_tag="l.auto") + + n = 512 + m = 256 + ell = 128 + params = {'n': n, 'm': m, 'ell': ell} + group_size = bsize*bsize + n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group + + mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True, + subgroup_size=SGS) + + f32s1lb = mem_access_map[lp.MemAccess('global', np.float32, + lid_strides={0: 1}, + gid_strides={1: bsize}, + direction='load', variable='b', + variable_tag='mmbload', + count_granularity=CG.WORKITEM) + ].eval_with_dict(params) + f32s1la = mem_access_map[lp.MemAccess('global', np.float32, + lid_strides={1: Variable('m')}, + gid_strides={0: Variable('m')*bsize}, + direction='load', + variable='a', + variable_tag='mmaload', + count_granularity=CG.SUBGROUP) + ].eval_with_dict(params) + + assert f32s1lb == n*m*ell + + # uniform: (count-per-sub-group)*n_subgroups + assert f32s1la == m*n_subgroups + + f32coal = mem_access_map[lp.MemAccess('global', np.float32, + lid_strides={0: 1, 1: Variable('ell')}, + gid_strides={0: Variable('ell')*bsize, 1: bsize}, + direction='store', variable='c', + variable_tag='mmresult', + count_granularity=CG.WORKITEM) + ].eval_with_dict(params) + + assert f32coal == n*ell + + def test_gather_access_footprint(): knl = lp.make_kernel( "{[i,k,j]: 0<=i,j,k<n}",