diff --git a/loopy/transform/save.py b/loopy/transform/save.py index 1c431fa10ab61109094e666423f43ae5906c5a65..fa98f478d7e09bd77a3dafcd0327c5d6dce737d0 100644 --- a/loopy/transform/save.py +++ b/loopy/transform/save.py @@ -298,7 +298,7 @@ class TemporarySaver(object): return result @memoize_method - def get_defining_global_barrier_pair(self, subkernel): + def get_enclosing_global_barrier_pair(self, subkernel): subkernel_start, subkernel_end = ( self.subkernel_to_slice_indices[subkernel]) @@ -413,10 +413,14 @@ class TemporarySaver(object): assert temporary.read_only return None - if temporary.base_storage in self.base_storage_to_representative: - # XXX: Todo: Warn about multiple base_storage - #repr = self.base_storage_to_representative[temporary.base_storage] - pass + base_storage_conflict = ( + self.base_storage_to_representative.get( + temporary.base_storage, temporary) is not temporary) + + if base_storage_conflict: + raise NotImplementedError( + "tried to save/reload multiple temporaries with the " + "same base_storage; this is currently not supported") hw_dims, hw_tags = self.get_hw_axis_sizes_and_tags_for_save_slot(temporary) non_hw_dims = temporary.shape @@ -492,7 +496,7 @@ class TemporarySaver(object): depends_on = frozenset() update_deps = accessing_insns_in_subkernel - pre_barrier, post_barrier = self.get_defining_global_barrier_pair(subkernel) + pre_barrier, post_barrier = self.get_enclosing_global_barrier_pair(subkernel) if pre_barrier is not None: depends_on |= set([pre_barrier]) diff --git a/test/test_loopy.py b/test/test_loopy.py index 1d1450fc0052976d621360b0262f4c1c686f8b88..19a00108449b3be6905ce10905a7f313a3194b7d 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2161,6 +2161,25 @@ def test_global_barrier_error_if_unordered(): knl.global_barrier_order +def test_multi_base_storage_save_and_reload_not_supported(): + # FIXME: This ought to work, change the test when it does. + knl = lp.make_kernel("{[i]: 0<=i<10}", + """ + <>a[0] = 1 + <>b[0] = 2 + ... gbarrier + out = a[0] + b[0] + """, + seq_dependencies=True) + + knl = lp.alias_temporaries(knl, ("a", "b"), synchronize_for_exclusive_use=False) + knl = lp.preprocess_kernel(knl) + knl = lp.get_one_scheduled_kernel(knl) + + with pytest.raises(NotImplementedError): + lp.save_and_reload_temporaries(knl) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])