diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 5f0884fd44ed5064f3f195d103b164f2163d1d19..f7ce5d9fc983c2ab946b5d959f283ef9328b7f29 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -616,10 +616,12 @@ def get_simple_strides(bset, key_by="name"): # recognizes constraints of the form # -i0 + 2*floor((i0)/2) == 0 - if aff.dim(dim_type.div) != 1: + divs_with_coeffs = _get_indices_and_coeffs(aff, [dim_type.div]) + if len(divs_with_coeffs) != 1: continue - idiv = 0 + (_, idiv, div_coeff), = divs_with_coeffs + div = aff.get_div(idiv) # check for sub-divs @@ -630,7 +632,7 @@ def get_simple_strides(bset, key_by="name"): denom = div.get_denominator_val().to_python() # if the coefficient in front of the div is not the same as the denominator - if not aff.get_coefficient_val(dim_type.div, idiv).div(denom).is_one(): + if not div_coeff.div(denom).is_one(): # not supported continue diff --git a/test/test_statistics.py b/test/test_statistics.py index a72b62af90050008f837e144f1f28d4a4de1c730..cf86539efec7be7e85fecfadc3b19d26fac7bb6d 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -805,6 +805,32 @@ def test_summations_and_filters(): assert s1f64l == 2*n*m +def test_strided_footprint(): + param_dict = dict(n=2**20) + knl = lp.make_kernel( + "[n] -> {[i]: 0<=i<n}", + [ + "z[i] = x[3*i]" + ], name="s3") + + knl = lp.add_and_infer_dtypes(knl, dict(x=np.float32)) + + unr = 4 + bx = 256 + + knl = lp.split_iname(knl, "i", bx*unr, outer_tag="g.0", slabs=(0, 1)) + knl = lp.split_iname(knl, "i_inner", bx, outer_tag="unr", inner_tag="l.0") + + footprints = lp.gather_access_footprints(knl) + x_l_foot = footprints[('x', 'read')] + + from loopy.statistics import count + num = count(knl, x_l_foot).eval_with_dict(param_dict) + denom = count(knl, x_l_foot.remove_divs()).eval_with_dict(param_dict) + + assert 2*num < denom + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])