diff --git a/loopy/schedule/device_mapping.py b/loopy/schedule/device_mapping.py index bf6f67e218aa2ddea183c769f3cf70348b9aec67..86f5c4a2cb6d00a5f38f9ead66202a0a09fa49e8 100644 --- a/loopy/schedule/device_mapping.py +++ b/loopy/schedule/device_mapping.py @@ -38,6 +38,7 @@ def map_schedule_onto_host_or_device(kernel): kernel = kernel.copy(schedule=new_schedule) else: kernel = map_schedule_onto_host_or_device_impl(kernel) + kernel = add_extra_args_to_schedule(kernel) return restore_and_save_temporaries(kernel) @@ -183,6 +184,27 @@ def get_def_and_use_lists_for_all_temporaries(kernel): return def_lists, use_lists + +def get_temporaries_defined_and_used_in_subrange( + kernel, schedule, start_idx, end_idx): + defs = set() + uses = set() + + from loopy.schedule import RunInstruction + + for idx in range(start_idx, end_idx + 1): + sched_item = schedule[idx] + if isinstance(sched_item, RunInstruction): + insn = kernel.id_to_insn[sched_item.insn_id] + defs.update( + filter_temporaries( + kernel, get_def_set(insn))) + uses.update( + filter_temporaries( + kernel, get_use_set(insn))) + + return defs, uses + # }}} @@ -439,6 +461,11 @@ def restore_and_save_temporaries(kernel): inter_kernel_temporaries |= filter_out_subscripts(live_in[idx]) call_count += 1 + if call_count == 1: + # Single call corresponds to a kernel which has not been split - + # no need for restores / spills of temporaries. + return kernel + name_gen = kernel.get_var_name_generator() new_temporaries = determine_temporaries_to_promote( kernel, inter_kernel_temporaries, name_gen) @@ -459,31 +486,20 @@ def restore_and_save_temporaries(kernel): idx += 1 continue - subkernel_defs = set() - subkernel_uses = set() subkernel_prolog = [] subkernel_epilog = [] subkernel_schedule = [] - # {{{ Determine what to load / spill - start_idx = idx - idx += 1 - # Analyze the variables used inside the subkernel. while not isinstance(schedule[idx], ReturnFromKernel): - subkernel_item = schedule[idx] - subkernel_schedule.append(subkernel_item) - if isinstance(subkernel_item, RunInstruction): - insn = kernel.id_to_insn[subkernel_item.insn_id] - subkernel_defs.update( - filter_temporaries( - kernel, get_def_set(insn))) - subkernel_uses.update( - filter_temporaries( - kernel, get_use_set(insn))) + subkernel_schedule.append(schedule[idx]) idx += 1 + subkernel_defs, subkernel_uses = \ + get_temporaries_defined_and_used_in_subrange( + kernel, schedule, start_idx + 1, idx - 1) + from loopy.kernel.data import temp_var_scope # Filter out temporaries that are global. subkernel_globals = set( @@ -498,11 +514,9 @@ def restore_and_save_temporaries(kernel): # Add new arguments. sched_item = sched_item.copy( - extra_args=sorted(subkernel_globals - | set(new_temporaries[tv].name - for tv in tvals_to_load | tvals_to_spill))) - - # }}} + extra_args=sched_item.extra_args + + sorted(new_temporaries[tv].name + for tv in tvals_to_load | tvals_to_spill)) # {{{ Add all the loads and spills. @@ -619,6 +633,27 @@ def restore_and_save_temporaries(kernel): return kernel +def add_extra_args_to_schedule(kernel): + new_schedule = [] + + from loopy.schedule import CallKernel + from loopy.kernel.data import temp_var_scope + + block_bounds = get_block_boundaries(kernel.schedule) + for idx, sched_item in enumerate(kernel.schedule): + if isinstance(sched_item, CallKernel): + defs, uses = get_temporaries_defined_and_used_in_subrange( + kernel, kernel.schedule, idx + 1, block_bounds[idx] - 1) + # Filter out temporaries that are global. + extra_args = (tv for tv in defs | uses if + kernel.temporary_variables[tv].scope == temp_var_scope.GLOBAL) + new_schedule.append(sched_item.copy(extra_args=sorted(extra_args))) + else: + new_schedule.append(sched_item) + + return kernel.copy(schedule=new_schedule) + + def map_schedule_onto_host_or_device_impl(kernel): from loopy.schedule import ( RunInstruction, EnterLoop, LeaveLoop, Barrier,