diff --git a/loopy/statistics.py b/loopy/statistics.py
index 9257cafc1ea0a56959b9f4901ad03089cd69c998..34027a5a0af0b11bef59ce9914d8573c484732ad 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -98,6 +98,7 @@ def _get_param_tuple(obj):
 
 class GuardedPwQPolynomial:
     def __init__(self, pwqpolynomial, valid_domain):
+        assert isinstance(pwqpolynomial, isl.PwQPolynomial)
         self.pwqpolynomial = pwqpolynomial
         self.valid_domain = valid_domain
 
@@ -664,10 +665,10 @@ class Op(ImmutableRecord):
     def __repr__(self):
         # Record.__repr__ overridden for consistent ordering and conciseness
         if self.kernel_name is not None:
-            return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
-                    f" {self.kernel_name})")
+            return (f'Op("{self.dtype}", "{self.name}", "{self.count_granularity}",'
+                    f' "{self.kernel_name}")')
         else:
-            return f"Op({self.dtype}, {self.name}, {self.count_granularity})"
+            return f'Op("{self.dtype}", "{self.name}", "{self.count_granularity}")'
 
 # }}}
 
@@ -1548,7 +1549,8 @@ def get_unused_hw_axes_factor(knl, callables_table, insn,
 def count_inames_domain(knl, inames):
     space = get_kernel_parameter_space(knl)
     if not inames:
-        return get_kernel_zero_pwqpolynomial(knl) + 1
+        return add_assumptions_guard(knl,
+                get_kernel_zero_pwqpolynomial(knl) + 1)
 
     inames_domain = knl.get_inames_domain(inames)
     domain = inames_domain.project_out_except(inames, [dim_type.set])
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 4136f8d064d50d4493f5a1c4a0c5c03d0351e936..ca38b9af63e7ca620ff001b368cc349761936da0 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -1400,6 +1400,59 @@ def test_strided_footprint():
     assert 2*num < denom
 
 
+def test_stats_on_callable_kernel():
+    """Check the "add" op count (400) when a caller invokes a callee once."""
+    callee = lp.make_function(
+            "{[i, j]: 0<=i, j< 20}",
+            """
+            y[i] = sum(j, A[i,j]*x[j])
+            """, name="matvec20x20")
+
+    caller = lp.make_kernel(
+            "{:}",
+            """
+            y[:] = matvec20x20(A[:,:], x[:])
+            """,
+            [
+                lp.GlobalArg("x,y", shape=(20,), dtype=np.float64),
+                lp.GlobalArg("A", shape=(20, 20), dtype=np.float64),
+            ],
+            name="matvec")
+    caller = lp.merge([caller, callee])
+
+    op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True,
+                           count_within_subscripts=True)
+    f64_add = op_map.filter_by(name="add").eval_and_sum({})
+    assert f64_add == 400
+
+
+def test_stats_on_callable_kernel_within_loop():
+    """Check the "add" op count (8000) when the callee is called in a loop."""
+    callee = lp.make_function(
+            "{[i, j]: 0<=i, j< 20}",
+            """
+            y[i] = sum(j, A[i,j]*x[j])
+            """, name="matvec20x20")
+
+    caller = lp.make_kernel(
+            "{[i]: 0<=i< 20}",
+            """
+            y[i, :] = matvec20x20(A[:,:], x[i, :])
+            """,
+            [
+                lp.GlobalArg("x,y", shape=(20, 20), dtype=np.float64),
+                lp.GlobalArg("A", shape=(20, 20), dtype=np.float64),
+            ],
+            name="matmat")
+    caller = lp.merge([caller, callee])
+
+    op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True,
+                           count_within_subscripts=True)
+
+    f64_add = op_map.filter_by(name="add").eval_and_sum({})
+    assert f64_add == 8000
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])