diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index e9e7c9a447afb559e3536ab3cb1219111a3a2e0d..730d331127e0937ad6037cbc18bc29270a612941 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -587,6 +587,8 @@ def generate_code_v2(program): codegen_results[func_id] = ( generate_code_for_a_single_kernel(in_knl_callable.subkernel, program.program_callables_info, program.target)) + if not in_knl_callable.subkernel.is_called_from_host: + assert codegen_results[func_id].host_program is None device_preambles = set() for cgr in codegen_results.values(): diff --git a/loopy/codegen/control.py b/loopy/codegen/control.py index 90bdbda31410e1d61b2a3e2f504a8dfcb74c23ed..bb62961c535a851ac4dc9e03724db3685395fe7c 100644 --- a/loopy/codegen/control.py +++ b/loopy/codegen/control.py @@ -117,16 +117,19 @@ def generate_code_for_sched_index(codegen_state, sched_index): glob_grid, loc_grid = kernel.get_grid_sizes_for_insn_ids_as_exprs( get_insn_ids_for_block_at(kernel.schedule, sched_index), codegen_state.program_callables_info) - - return merge_codegen_results(codegen_state, [ - codegen_result, - - codegen_state.ast_builder.get_kernel_call( - codegen_state, - sched_item.kernel_name, - glob_grid, loc_grid, - extra_args), - ]) + if kernel.is_called_from_host: + return merge_codegen_results(codegen_state, [ + codegen_result, + + codegen_state.ast_builder.get_kernel_call( + codegen_state, + sched_item.kernel_name, + glob_grid, loc_grid, + extra_args), + ]) + else: + # do not generate host code for callee kernels + return codegen_result elif isinstance(sched_item, EnterLoop): tags = kernel.iname_tags(sched_item.iname) diff --git a/loopy/codegen/result.py b/loopy/codegen/result.py index 00f19d99afa7119f35b188492646669f71b850d1..7950c56b3b62693f974cbcc5ab8686f30fa42cbe 100644 --- a/loopy/codegen/result.py +++ b/loopy/codegen/result.py @@ -292,27 +292,32 @@ def generate_host_or_device_program(codegen_state, schedule_index): else: codegen_result = build_loop_nest(codegen_state, schedule_index) - codegen_result = merge_codegen_results( - codegen_state, - ast_builder.generate_top_of_body(codegen_state) - + temp_decls - + [codegen_result], - collapse=False) - - cur_prog = codegen_result.current_program(codegen_state) - body_ast = cur_prog.ast - fdecl_ast = ast_builder.get_function_declaration( - codegen_state, codegen_result, schedule_index) - - fdef_ast = ast_builder.get_function_definition( - codegen_state, codegen_result, - schedule_index, fdecl_ast, body_ast) - - codegen_result = codegen_result.with_new_program( - codegen_state, - cur_prog.copy( - ast=ast_builder.process_ast(fdef_ast), - body_ast=ast_builder.process_ast(body_ast))) + if (codegen_state.is_generating_device_code) or ( + codegen_state.kernel.is_called_from_host): + codegen_result = merge_codegen_results( + codegen_state, + ast_builder.generate_top_of_body(codegen_state) + + temp_decls + + [codegen_result], + collapse=False) + + cur_prog = codegen_result.current_program(codegen_state) + body_ast = cur_prog.ast + fdecl_ast = ast_builder.get_function_declaration( + codegen_state, codegen_result, schedule_index) + + fdef_ast = ast_builder.get_function_definition( + codegen_state, codegen_result, + schedule_index, fdecl_ast, body_ast) + + codegen_result = codegen_result.with_new_program( + codegen_state, + cur_prog.copy( + ast=ast_builder.process_ast(fdef_ast), + body_ast=ast_builder.process_ast(body_ast))) + else: + codegen_result = codegen_result.copy( + host_program=None) return codegen_result