diff --git a/loopy/transform/save.py b/loopy/transform/save.py index 8706bc4da70b94ad678f07158e0a0f648fdd0030..65649a4cc0e3067cdb13646b2f40bbefc414ce07 100644 --- a/loopy/transform/save.py +++ b/loopy/transform/save.py @@ -26,7 +26,7 @@ THE SOFTWARE. from loopy.diagnostic import LoopyError import loopy as lp -from loopy.kernel.data import auto +from loopy.kernel.data import auto, temp_var_scope from pytools import memoize_method, Record from loopy.schedule import ( EnterLoop, LeaveLoop, RunInstruction, @@ -217,7 +217,7 @@ class TemporarySaver(object): @memoize_method def as_variable(self): temporary = self.orig_temporary - from loopy.kernel.data import TemporaryVariable, temp_var_scope + from loopy.kernel.data import TemporaryVariable return TemporaryVariable( name=self.name, dtype=temporary.dtype, @@ -245,11 +245,15 @@ class TemporarySaver(object): def auto_promote_temporary(self, temporary_name): temporary = self.kernel.temporary_variables[temporary_name] - from loopy.kernel.data import temp_var_scope if temporary.scope == temp_var_scope.GLOBAL: # Nothing to be done for global temporaries (I hope) return None + if temporary.initializer is not None: + # Temporaries with initializers do not need saving/reloading - the + # code generation takes care of emitting the initializers. + return None + if temporary.base_storage is not None: raise ValueError( "Cannot promote temporaries with base_storage to global") diff --git a/test/test_loopy.py b/test/test_loopy.py index d208793f5bde68e3e56fa3aaef9b84cc2f4a8b8c..79bf52237cbe2d69807aae99d199e59f8a60d922 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1708,6 +1708,31 @@ def test_temp_initializer(ctx_factory, src_order, tmp_order): assert np.array_equal(a, a2) +def test_const_temp_with_initializer_not_saved(): + knl = lp.make_kernel( + "{[i]: 0<=i<10}", + """ + ... gbarrier + out[i] = tmp[i] + """, + [ + lp.TemporaryVariable("tmp", + initializer=np.arange(10), + shape=lp.auto, + scope=lp.temp_var_scope.PRIVATE, + read_only=True), + "..." + ], + seq_dependencies=True) + + knl = lp.preprocess_kernel(knl) + knl = lp.get_one_scheduled_kernel(knl) + knl = lp.save_and_reload_temporaries(knl) + + # This ensures no save slot was added. + assert len(knl.temporary_variables) == 1 + + def test_header_extract(): knl = lp.make_kernel('{[k]: 0<=k<n}}', """