diff --git a/loopy/statistics.py b/loopy/statistics.py
index 34027a5a0af0b11bef59ce9914d8573c484732ad..c8689605420947c5cdf0b58c0dcb31e3b01014a6 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -841,21 +841,24 @@ class CounterBase(CombineMapper):
         assert isinstance(expr.function, ResolvedFunction)
         clbl = self.callables_table[expr.function.name]
 
-        from loopy.kernel.function_interface import CallableKernel
+        from loopy.kernel.function_interface import (CallableKernel,
+                get_kw_pos_association)
         from loopy.kernel.data import ValueArg
         if isinstance(clbl, CallableKernel):
             sub_result = self.kernel_rec(clbl.subkernel)
+            # pos_to_kw gives the callee argument name for each call
+            # position; only ValueArg bindings are substituted below.
+            _, pos_to_kw = get_kw_pos_association(clbl.subkernel)
 
-            arg_dict = {
-                    arg.name: value
-                    for arg, value in zip(
-                        clbl.subkernel.args,
-                        expr.parameters)
-                    if isinstance(arg, ValueArg)}
+            subst_dict = {
+                    pos_to_kw[i]: param
+                    for i, param in enumerate(expr.parameters)
+                    if isinstance(clbl.subkernel.arg_dict[pos_to_kw[i]],
+                                  ValueArg)}
 
             return subst_into_to_count_map(
                     self.param_space,
-                    sub_result, arg_dict) \
+                    sub_result, subst_dict) \
                     + self.rec(expr.parameters)
 
         else:
diff --git a/test/test_statistics.py b/test/test_statistics.py
index ca38b9af63e7ca620ff001b368cc349761936da0..4991793516d18814bffdf5948bf2e81814d6d65a 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -1451,6 +1451,37 @@ def test_stats_on_callable_kernel_within_loop():
     assert f64_add == 8000
 
 
+def test_callable_kernel_with_substitution():
+    """Check that a literal passed for a callee ValueArg reaches the op count."""
+    callee = lp.make_function(
+            "{[i, j]: 0<=i, j< n}",
+            """
+            y[i] = sum(j, A[i,j]*x[j])
+            """,
+            [lp.ValueArg("n"), ...],
+            name="matvec")
+
+    caller = lp.make_kernel(
+            "{[i]: 0<=i< 20}",
+            """
+            y[i, :] = matvec(20, A[:,:], x[i, :])
+            """,
+            [
+                lp.GlobalArg("x,y", shape=(20, 20), dtype=np.float64),
+                lp.GlobalArg("A", shape=(20, 20), dtype=np.float64),
+            ],
+            name="matmat")
+    caller = lp.merge([caller, callee])
+
+    op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True,
+                           count_within_subscripts=True)
+
+    # 20 caller iterations, each running the 20 x 20 matvec reduction once n
+    # has been substituted by the literal 20: 20 * 20 * 20 adds.
+    f64_add = op_map.filter_by(name="add").eval_and_sum({})
+    assert f64_add == 8000
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])