From 55f9418f4eb7794c05f14eb34a8157ff10039d38 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 10 Aug 2017 21:25:54 -0500
Subject: [PATCH] Fix recognition of striding constraints in counting when
 other, unused divs are present

---
 loopy/isl_helpers.py    |  8 +++++---
 test/test_statistics.py | 26 ++++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py
index 5f0884fd4..f7ce5d9fc 100644
--- a/loopy/isl_helpers.py
+++ b/loopy/isl_helpers.py
@@ -616,10 +616,12 @@ def get_simple_strides(bset, key_by="name"):
         # recognizes constraints of the form
         #  -i0 + 2*floor((i0)/2) == 0
 
-        if aff.dim(dim_type.div) != 1:
+        divs_with_coeffs = _get_indices_and_coeffs(aff, [dim_type.div])
+        if len(divs_with_coeffs) != 1:
             continue
 
-        idiv = 0
+        (_, idiv, div_coeff), = divs_with_coeffs
+
         div = aff.get_div(idiv)
 
         # check for sub-divs
@@ -630,7 +632,7 @@ def get_simple_strides(bset, key_by="name"):
         denom = div.get_denominator_val().to_python()
 
         # if the coefficient in front of the div is not the same as the denominator
-        if not aff.get_coefficient_val(dim_type.div, idiv).div(denom).is_one():
+        if not div_coeff.div(denom).is_one():
             # not supported
             continue
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index a72b62af9..cf86539ef 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -805,6 +805,32 @@ def test_summations_and_filters():
     assert s1f64l == 2*n*m
 
 
+def test_strided_footprint():
+    param_dict = dict(n=2**20)
+    knl = lp.make_kernel(
+        "[n] -> {[i]: 0<=i<n}",
+        [
+            "z[i] = x[3*i]"
+        ], name="s3")
+
+    knl = lp.add_and_infer_dtypes(knl, dict(x=np.float32))
+
+    unr = 4
+    bx = 256
+
+    knl = lp.split_iname(knl, "i", bx*unr, outer_tag="g.0", slabs=(0, 1))
+    knl = lp.split_iname(knl, "i_inner", bx, outer_tag="unr", inner_tag="l.0")
+
+    footprints = lp.gather_access_footprints(knl)
+    x_l_foot = footprints[('x', 'read')]
+
+    from loopy.statistics import count
+    num = count(knl, x_l_foot).eval_with_dict(param_dict)
+    denom = count(knl, x_l_foot.remove_divs()).eval_with_dict(param_dict)
+
+    assert 2*num < denom
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab