# Recovered from the diff for loopy/isl_helpers.py (hunk @@ -512,4 +512,62 @@),
# where this helper is added after dim_max_with_elimination.

def get_simple_strides(bset, key_by="name"):
    """Return a dictionary mapping dimensions of *bset* to their strides.

    The strides are read off the integer divisions ("divs") of the local
    space of *bset*. A div contributes a stride only if

    * it does not depend on any other div,
    * exactly one parameter or set dimension has a non-zero coefficient in
      it, and
    * that coefficient, multiplied by the div's denominator, equals 1.

    The stride recorded for that dimension is the div's denominator,
    converted with ``Val.to_python()`` (i.e. it is *not* an
    :class:`islpy.Val`). Dimensions for which no stride can be determined
    have no entry in the result. The constant term of a div is not
    inspected. If several divs qualify for the same dimension, the last
    one wins.

    This only recognizes simple strides involving single variables.

    :arg bset: the basic set whose divs are examined.
    :arg key_by: ``"name"`` to key the result by dimension name, or
        ``"index"`` to key it by ``(dim_type, index)`` tuples.
    :returns: a :class:`dict` from keys (see *key_by*) to strides.
    :raises ValueError: if *key_by* is neither ``"name"`` nor ``"index"``.
    """
    # Validate eagerly: otherwise a bad value only surfaces once a stride is
    # actually found and is silently accepted for sets without strides.
    if key_by not in ("name", "index"):
        raise ValueError("invalid value of 'key_by': %r" % (key_by,))

    result = {}

    lspace = bset.get_local_space()
    for idiv in range(lspace.dim(dim_type.div)):
        div = lspace.get_div(idiv)

        # Divs that are defined in terms of other divs are not supported.
        if any(
                not div.get_coefficient_val(dim_type.div, sub_idx).is_zero()
                for sub_idx in range(div.dim(dim_type.div))):
            continue

        denom = div.get_denominator_val().to_python()

        # Collect every parameter/input dimension that takes part in the
        # div, with its coefficient scaled by the denominator.
        inames_and_coeffs = []
        for dt in [dim_type.param, dim_type.in_]:
            for dim_idx in range(div.dim(dt)):
                coeff_val = div.get_coefficient_val(dt, dim_idx) * denom
                if not coeff_val.is_zero():
                    inames_and_coeffs.append((dt, dim_idx, coeff_val))

        if len(inames_and_coeffs) != 1:
            # Not a single-variable div.
            continue

        (dt, dim_idx, coeff), = inames_and_coeffs

        if coeff != 1:
            # Scaled variables are not supported.
            continue

        # The div addresses the set dimensions as its *input* dimensions;
        # on the set itself the same dimensions are of type 'set'. Apply
        # that translation for both kinds of key.
        key_dt = dim_type.set if dt == dim_type.in_ else dt

        if key_by == "name":
            key = bset.get_dim_name(key_dt, dim_idx)
        else:
            key = (key_dt, dim_idx)

        result[key] = denom

    return result

# vim: foldmethod=marker

# What follows in this chunk belongs to the diff for loopy/statistics.py
# (hunk @@ -456,34 +456,95 @@).
# Recovered from the diff for loopy/statistics.py (hunk @@ -456,34 +456,95 @@,
# located after class AccessFootprintGatherer(CombineMapper)).

# {{{ count

def count(kernel, set):
    """Return the number of integer points in *set*.

    If *set* offers a ``card()`` method, its result is returned directly.
    Otherwise (the warnings below attribute this to the Barvinok wrappers
    not being installed) the count is estimated: *set* is made disjoint and,
    for each of its basic sets, the extents of all set dimensions are
    multiplied, with the extent of dimension *i* taken as
    ``(dim_max(i) - dim_min(i) + stride) / stride``. Strides come from
    :func:`loopy.isl_helpers.get_simple_strides` and default to 1.

    To detect a wrong estimate, each basic set is rebuilt from exactly the
    bounds and strides that were used for counting and compared with the
    original. A mismatch triggers a ``count_overestimate``,
    ``count_underestimate`` or ``count_misestimate`` warning.

    :arg kernel: passed through to :func:`loopy.diagnostic.warn`.
    :arg set: the set whose points are counted.
    :returns: the result of ``set.card()`` if available; otherwise the sum
        of the per-basic-set estimates, accumulated starting from the
        integer 0 (so 0 is returned if nothing was added).
    """
    try:
        return set.card()
    except AttributeError:
        pass

    from loopy.isl_helpers import get_simple_strides

    total = 0

    for bset in set.make_disjoint().get_basic_sets():
        bset_count = None
        bset_rebuilt = bset.universe(bset.space)

        bset_strides = get_simple_strides(bset, key_by="index")

        # Loop-invariant pieces used to rebuild the check domain.
        ndim = bset.dim(isl.dim_type.set)
        zero = isl.Aff.zero_on_domain(isl.LocalSpace.from_space(bset.space))
        dim_ids = [
                (idx, bset.get_dim_id(isl.dim_type.set, idx))
                for idx in range(ndim)
                if bset.has_dim_id(isl.dim_type.set, idx)]

        for i in range(ndim):
            dmax = bset.dim_max(i)
            dmin = bset.dim_min(i)

            stride = bset_strides.get((isl.dim_type.set, i), 1)

            length = isl.PwQPolynomial.from_pw_aff(dmax - dmin + stride)
            length = length.scale_down_val(stride)

            if bset_count is None:
                bset_count = length
            else:
                bset_count = bset_count * length

            # {{{ rebuild check domain

            # Affine expression that evaluates to the i-th set dimension.
            iname = isl.PwAff.from_aff(
                    zero.set_coefficient_val(isl.dim_type.in_, i, 1))

            # Give the bounds the same input dimensions (and ids) as
            # 'iname' so that they can be compared with it.
            dmin_matched = dmin.insert_dims(isl.dim_type.in_, 0, ndim)
            dmax_matched = dmax.insert_dims(isl.dim_type.in_, 0, ndim)
            for idx, dim_id in dim_ids:
                dmin_matched = dmin_matched.set_dim_id(
                        isl.dim_type.in_, idx, dim_id)
                dmax_matched = dmax_matched.set_dim_id(
                        isl.dim_type.in_, idx, dim_id)

            bset_rebuilt = (
                    bset_rebuilt
                    & iname.le_set(dmax_matched)
                    & iname.ge_set(dmin_matched)
                    & (iname-dmin_matched).mod_val(stride).eq_set(zero))

            # }}}

        if bset_count is not None:
            total += bset_count

        is_subset = bset <= bset_rebuilt
        is_superset = bset >= bset_rebuilt

        if not (is_subset and is_superset):
            if is_subset:
                # The rebuilt (counted) domain contains extra points.
                warn_id, verb = "count_overestimate", "overestimated"
            elif is_superset:
                # The rebuilt (counted) domain misses points.
                warn_id, verb = "count_underestimate", "underestimated"
            else:
                warn_id, verb = "count_misestimate", "misestimated"

            from loopy.diagnostic import warn
            warn(kernel, warn_id,
                    "Barvinok wrappers are not installed. "
                    "Counting routines have %s the "
                    "number of integer points in your loop "
                    "domain." % verb)

    return total

# }}}


# Recovered from the diff for test/test_isl.py (hunk @@ -36,6 +36,14 @@),
# added after test_aff_to_expr.

def test_aff_to_expr_2():
    """Check aff_to_expr on an affine expression containing a floor-div."""
    from loopy.symbolic import aff_to_expr
    x = isl.Aff("[n] -> { [i0] -> [(-i0 + 2*floor((i0)/2))] }")
    from pymbolic import var
    i0 = var("i0")
    assert aff_to_expr(x) == (-1)*i0 + 2*(i0 // 2)


# Recovered from the diff for test/test_statistics.py
# (hunk @@ -618,6 +618,22 @@), added after test_gather_access_footprint.

def test_gather_access_footprint_2():
    """Check that every access footprint of a kernel writing c[2*i] for
    0 <= i < n is counted as 200 for n = 200."""
    knl = lp.make_kernel(
            "{[i]: 0<=i<n}",
            "c[2*i] = a[i]",
            name="matmul", assumptions="n >= 1")
    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32))

    from loopy.statistics import gather_access_footprints, count
    fp = gather_access_footprints(knl)

    params = {"n": 200}
    for key, footprint in six.iteritems(fp):
        assert count(knl, footprint).eval_with_dict(params) == 200
        print(key, count(knl, footprint))