From f0044292654dbf7aa1aed1e54e5233005b1824ac Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Fri, 17 Mar 2017 20:55:07 -0500
Subject: [PATCH] Save and reload: For now, make sure we only save/reload one
 representative per base_storage class (see also: #42).

---
 loopy/transform/save.py | 16 ++++++++++------
 test/test_loopy.py      | 19 +++++++++++++++++++
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/loopy/transform/save.py b/loopy/transform/save.py
index 1c431fa10..fa98f478d 100644
--- a/loopy/transform/save.py
+++ b/loopy/transform/save.py
@@ -298,7 +298,7 @@ class TemporarySaver(object):
         return result
 
     @memoize_method
-    def get_defining_global_barrier_pair(self, subkernel):
+    def get_enclosing_global_barrier_pair(self, subkernel):
         subkernel_start, subkernel_end = (
             self.subkernel_to_slice_indices[subkernel])
 
@@ -413,10 +413,14 @@ class TemporarySaver(object):
             assert temporary.read_only
             return None
 
-        if temporary.base_storage in self.base_storage_to_representative:
-            # XXX: Todo: Warn about multiple base_storage
-            #repr = self.base_storage_to_representative[temporary.base_storage]
-            pass
+        base_storage_conflict = (
+            self.base_storage_to_representative.get(
+                temporary.base_storage, temporary) is not temporary)
+
+        if base_storage_conflict:
+            raise NotImplementedError(
+                "tried to save/reload multiple temporaries with the "
+                "same base_storage; this is currently not supported")
 
         hw_dims, hw_tags = self.get_hw_axis_sizes_and_tags_for_save_slot(temporary)
         non_hw_dims = temporary.shape
@@ -492,7 +496,7 @@ class TemporarySaver(object):
             depends_on = frozenset()
             update_deps = accessing_insns_in_subkernel
 
-        pre_barrier, post_barrier = self.get_defining_global_barrier_pair(subkernel)
+        pre_barrier, post_barrier = self.get_enclosing_global_barrier_pair(subkernel)
 
         if pre_barrier is not None:
             depends_on |= set([pre_barrier])
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 1d1450fc0..19a001084 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2161,6 +2161,25 @@ def test_global_barrier_error_if_unordered():
         knl.global_barrier_order
 
 
+def test_multi_base_storage_save_and_reload_not_supported():
+    # FIXME: This ought to work, change the test when it does.
+    knl = lp.make_kernel("{[i]: 0<=i<10}",
+        """
+        <>a[0] = 1
+        <>b[0] = 2
+        ... gbarrier
+        out = a[0] + b[0]
+        """,
+        seq_dependencies=True)
+
+    knl = lp.alias_temporaries(knl, ("a", "b"), synchronize_for_exclusive_use=False)
+    knl = lp.preprocess_kernel(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+
+    with pytest.raises(NotImplementedError):
+        lp.save_and_reload_temporaries(knl)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab