diff --git a/loopy/statistics.py b/loopy/statistics.py index 9ce2bb081eca67cc6f41864c7ce5965e018ce853..10d29daad062744ca3fbe2dc2261be4cd2c4ca99 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -32,6 +32,7 @@ from functools import reduce from loopy.kernel.data import ( MultiAssignmentBase, TemporaryVariable, AddressSpace) from loopy.diagnostic import warn_with_kernel, LoopyError +from loopy.symbolic import CoefficientCollector from pytools import Record, memoize_method @@ -843,6 +844,19 @@ class ExpressionOpCounter(CounterBase): # }}} +# {{{ modified coefficient collector that ignores denominator of floor div + +class _IndexStrideCoefficientCollector(CoefficientCollector): + + def map_floor_div(self, expr): + from warnings import warn + warn("_IndexStrideCoefficientCollector encountered FloorDiv, ignoring " + "denominator in expression %s" % (expr)) + return self.rec(expr.numerator) + +# }}} + + def _get_lid_and_gid_strides(knl, array, index): # find all local and global index tags and corresponding inames from loopy.symbolic import get_dependencies @@ -870,7 +884,6 @@ def _get_lid_and_gid_strides(knl, array, index): # where l0, l1, l2, g0, g1, and g2 come from flattened index # [... + g2*gid2 + g1*gid1 + g0*gid0 + ... + l2*lid2 + l1*lid1 + l0*lid0] - from loopy.symbolic import CoefficientCollector from loopy.kernel.array import FixedStrideArrayDimTag from pymbolic.primitives import Variable from loopy.symbolic import simplify_using_aff @@ -884,7 +897,7 @@ def _get_lid_and_gid_strides(knl, array, index): for idx, axis_tag in zip(index, array.dim_tags): # collect index coefficients try: - coeffs = CoefficientCollector()( + coeffs = _IndexStrideCoefficientCollector()( simplify_using_aff(knl, idx)) except ExpressionNotAffineError: total_iname_stride = None diff --git a/test/test_statistics.py b/test/test_statistics.py index b29edf1ed05f7728b2cbe5b5ad8a74c26944ed8c..41a88b3864166b81d60ec0468cf9e5fbd07c227c 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1060,6 +1060,50 @@ def test_all_counters_parallel_matmul(): assert local_mem_s == m*2/bsize*n_subgroups +def test_floor_div_coefficient_collector(): + bsize = 16 + + # kernel that shuffles local mem + knl = lp.make_kernel( + "{[i_outer,j_outer,i_inner,j_inner,r]: " + "0<=i_outer loc[i_inner,j_inner] = 3.14 {id=loc_init}", + "loc[i_inner,(j_inner+r+4) %% %d] = loc[i_inner,(j_inner+r) %% %d]" + " {id=add,dep=loc_init}" % (bsize, bsize), + "out0[i_outer*16+i_inner,j_outer*16+j_inner] = loc[i_inner,j_inner]" + " {id=store,dep=add}", + "end", + "end", + ], + name="local", + lang_version=(2018, 2)) + + knl = lp.add_and_infer_dtypes(knl, dict(out0=np.float32)) + knl = lp.tag_inames(knl, "i_outer:g.1,i_inner:l.1,j_outer:g.0,j_inner:l.0") + + n = 512 + rept = 64 + params = {"n": n, "rept": rept} + group_size = bsize*bsize + n_workgroups = div_ceil(n, bsize)*div_ceil(n, bsize) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group + + # count local f32 accesses + f32_local = lp.get_mem_access_map( + knl, count_redundant_work=True, subgroup_size=SGS + ).filter_by(dtype=[np.float32], mtype=["local"]).eval_and_sum(params) + + # (count-per-sub-group)*n_subgroups + assert f32_local == 2*(rept+1)*n_subgroups + + def test_mem_access_tagged_variables(): bsize = 16 knl = lp.make_kernel(