From dbb086e4fdbf687dd340b8ba4dcffa8ee574d631 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 1 Feb 2021 17:08:41 -0600
Subject: [PATCH] tests statistics for callable kernels

---
Notes (not part of the commit message):
  * GuardedPwQPolynomial.__init__ now asserts that its first argument is an
    isl.PwQPolynomial.
  * Op.__repr__ wraps every field in double quotes.
  * count_inames_domain passes its no-inames result through
    add_assumptions_guard.
  * The two new tests check the "add" op count of a 20x20 matvec callee,
    called once (400) and from a 20-trip caller loop (8000).
  * Review fix: the tests use np.float64 rather than np.float, an alias
    deprecated in NumPy 1.20 and removed in 1.24. The post-image blob id on
    the test/test_statistics.py index line predates this edit.

 loopy/statistics.py     | 10 ++++----
 test/test_statistics.py | 51 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 9257cafc1..34027a5a0 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -98,6 +98,7 @@ def _get_param_tuple(obj):
 
 class GuardedPwQPolynomial:
     def __init__(self, pwqpolynomial, valid_domain):
+        assert isinstance(pwqpolynomial, isl.PwQPolynomial)
         self.pwqpolynomial = pwqpolynomial
         self.valid_domain = valid_domain
 
@@ -664,10 +665,10 @@ class Op(ImmutableRecord):
     def __repr__(self):
         # Record.__repr__ overridden for consistent ordering and conciseness
         if self.kernel_name is not None:
-            return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
-                    f" {self.kernel_name})")
+            return (f'Op("{self.dtype}", "{self.name}", "{self.count_granularity}",'
+                    f' "{self.kernel_name}")')
         else:
-            return f"Op({self.dtype}, {self.name}, {self.count_granularity})"
+            return f'Op("{self.dtype}", "{self.name}", "{self.count_granularity}")'
 
 # }}}
 
@@ -1548,7 +1549,8 @@ def get_unused_hw_axes_factor(knl, callables_table, insn,
 def count_inames_domain(knl, inames):
     space = get_kernel_parameter_space(knl)
     if not inames:
-        return get_kernel_zero_pwqpolynomial(knl) + 1
+        return add_assumptions_guard(knl,
+                get_kernel_zero_pwqpolynomial(knl) + 1)
 
     inames_domain = knl.get_inames_domain(inames)
     domain = inames_domain.project_out_except(inames, [dim_type.set])
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 4136f8d06..ca38b9af6 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -1400,6 +1400,57 @@ def test_strided_footprint():
     assert 2*num < denom
 
 
+def test_stats_on_callable_kernel():
+    callee = lp.make_function(
+            "{[i, j]: 0<=i, j< 20}",
+            """
+            y[i] = sum(j, A[i,j]*x[j])
+            """, name="matvec20x20")
+
+    caller = lp.make_kernel(
+            "{:}",
+            """
+            y[:] = matvec20x20(A[:,:], x[:])
+            """,
+            [
+                lp.GlobalArg("x,y", shape=(20,), dtype=np.float64),
+                lp.GlobalArg("A", shape=(20, 20), dtype=np.float64),
+            ],
+            name="matvec")
+    caller = lp.merge([caller, callee])
+
+    op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True,
+            count_within_subscripts=True)
+    f64_add = op_map.filter_by(name="add").eval_and_sum({})
+    assert f64_add == 400
+
+
+def test_stats_on_callable_kernel_within_loop():
+    callee = lp.make_function(
+            "{[i, j]: 0<=i, j< 20}",
+            """
+            y[i] = sum(j, A[i,j]*x[j])
+            """, name="matvec20x20")
+
+    caller = lp.make_kernel(
+            "{[i]: 0<=i< 20}",
+            """
+            y[i, :] = matvec20x20(A[:,:], x[i, :])
+            """,
+            [
+                lp.GlobalArg("x,y", shape=(20, 20), dtype=np.float64),
+                lp.GlobalArg("A", shape=(20, 20), dtype=np.float64),
+            ],
+            name="matmat")
+    caller = lp.merge([caller, callee])
+
+    op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True,
+            count_within_subscripts=True)
+
+    f64_add = op_map.filter_by(name="add").eval_and_sum({})
+    assert f64_add == 8000
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab