diff --git a/loopy/schedule/device_mapping.py b/loopy/schedule/device_mapping.py index c877f1c1751f3ab64adca800e8e8ff50557fc02e..4cd5a8e8a016f6799425091f72acb3a7b96b16e3 100644 --- a/loopy/schedule/device_mapping.py +++ b/loopy/schedule/device_mapping.py @@ -297,6 +297,7 @@ def restore_and_save_temporaries(kernel): new_temporaries = {} name_gen = kernel.get_var_name_generator() + from loopy.kernel.data import LocalIndexTag, temp_var_scope from pytools import Record class PromotedTemporary(Record): @@ -322,10 +323,10 @@ def restore_and_save_temporaries(kernel): def as_variable(self): temporary = self.orig_temporary from loopy.kernel.data import TemporaryVariable - # XXX: This needs to be marked as global. return TemporaryVariable( name=self.name, dtype=temporary.dtype, + scope=temp_var_scope.GLOBAL, shape=self.new_shape) @property @@ -333,8 +334,6 @@ def restore_and_save_temporaries(kernel): return self.shape_prefix + self.orig_temporary.shape for temporary in inter_kernel_temporaries: - from loopy.kernel.data import LocalIndexTag, temp_var_scope - temporary = kernel.temporary_variables[temporary] if temporary.scope == temp_var_scope.GLOBAL: # Nothing to be done for global temporaries (I hope) @@ -415,16 +414,23 @@ def restore_and_save_temporaries(kernel): kernel, get_use_set(insn))) idx += 1 - tvals_to_spill = subkernel_defs & live_out[idx] + # Filter out temporaries that are global. + subkernel_globals = set( + tval for tval in subkernel_defs | subkernel_uses + if kernel.temporary_variables[tval].scope == temp_var_scope.GLOBAL) + + tvals_to_spill = (subkernel_defs - subkernel_globals) & live_out[idx] # Need to load tvals_to_spill, to avoid overwriting entries that the # code doesn't touch when doing the spill. - tvals_to_load = (subkernel_uses | tvals_to_spill) & live_in[start_idx] + tvals_to_load = ((subkernel_uses - subkernel_globals) + | tvals_to_spill) & live_in[start_idx] # Add arguments. new_schedule.append( sched_item.copy(extra_args=sorted( set(new_temporaries[tval].name - for tval in tvals_to_spill | tvals_to_load)))) + for tval in tvals_to_spill | tvals_to_load) + | subkernel_globals))) import islpy as isl