From e658837714adc9bd738e1670d325b7aaedfff223 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 1 Feb 2021 20:16:18 -0600 Subject: [PATCH] minor fix to correct the substitution of caller args in callee's stats exprs --- loopy/statistics.py | 17 +++++++++-------- test/test_statistics.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 34027a5a0..c86896054 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -841,21 +841,22 @@ class CounterBase(CombineMapper): assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] - from loopy.kernel.function_interface import CallableKernel + from loopy.kernel.function_interface import (CallableKernel, + get_kw_pos_association) from loopy.kernel.data import ValueArg if isinstance(clbl, CallableKernel): sub_result = self.kernel_rec(clbl.subkernel) + _, pos_to_kw = get_kw_pos_association(clbl.subkernel) - arg_dict = { - arg.name: value - for arg, value in zip( - clbl.subkernel.args, - expr.parameters) - if isinstance(arg, ValueArg)} + subst_dict = { + pos_to_kw[i]: param + for i, param in enumerate(expr.parameters) + if isinstance(clbl.subkernel.arg_dict[pos_to_kw[i]], + ValueArg)} return subst_into_to_count_map( self.param_space, - sub_result, arg_dict) \ + sub_result, subst_dict) \ + self.rec(expr.parameters) else: diff --git a/test/test_statistics.py b/test/test_statistics.py index ca38b9af6..499179351 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1451,6 +1451,34 @@ def test_stats_on_callable_kernel_within_loop(): assert f64_add == 8000 +def test_callable_kernel_with_substitution(): + callee = lp.make_function( + "{[i, j]: 0<=i, j< n}", + """ + y[i] = sum(j, A[i,j]*x[j]) + """, + [lp.ValueArg("n"), ...], + name="matvec") + + caller = lp.make_kernel( + "{[i]: 0<=i< 20}", + """ + y[i, :] = matvec(20, A[:,:], x[i, :]) + """, + [ + lp.GlobalArg("x,y", 
shape=(20, 20), dtype=np.float64), + lp.GlobalArg("A", shape=(20, 20), dtype=np.float64), + ], + name="matmat") + caller = lp.merge([caller, callee]) + + op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True, + count_within_subscripts=True) + + f64_add = op_map.filter_by(name="add").eval_and_sum({}) + assert f64_add == 8000 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab