diff --git a/loopy/statistics.py b/loopy/statistics.py index 8d9c371a752923460cca20192dd23d34a366aa02..44d55d10f9997d2e2cec74bf21c63a089c94a524 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -929,7 +929,7 @@ class AccessFootprintGatherer(CombineMapper): # {{{ count -def count(kernel, set): +def count(kernel, set, space=None): try: return set.card() except AttributeError: @@ -958,7 +958,11 @@ def count(kernel, set): if stride is None: stride = 1 - length = isl.PwQPolynomial.from_pw_aff(dmax - dmin + stride) + length_pwaff = dmax - dmin + stride + if space is not None: + length_pwaff = length_pwaff.align_params(space) + + length = isl.PwQPolynomial.from_pw_aff(length_pwaff) length = length.scale_down_val(stride) if bset_count is None: @@ -1068,10 +1072,15 @@ def count_insn_runs(knl, insn, disregard_local_axes=False): inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except( insn_inames, [dim_type.set])) - c = count(knl, domain) - return (c * get_unused_hw_axes_factor(knl, insn, - disregard_local_axes=disregard_local_axes, - space=c.space)) + + space = isl.Space.create_from_names(isl.DEFAULT_CONTEXT, + set=[], params=knl.outer_params()) + + c = count(knl, domain, space=space) + unused_fac = get_unused_hw_axes_factor(knl, insn, + disregard_local_axes=disregard_local_axes, + space=space) + return c * unused_fac # }}}