From 3ffa79e50b3519b6ab5733adca4547807d10687f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 18 Feb 2014 15:52:26 -0600
Subject: [PATCH] Make a way to disable caching, use it in tests that check
 warning generation

---
 doc/reference.rst         |  7 +++++++
 loopy/__init__.py         | 34 ++++++++++++++++++++++++++++++++++
 loopy/codegen/__init__.py | 31 +++++++++++++++++++------------
 loopy/preprocess.py       | 31 +++++++++++++++++++------------
 loopy/schedule.py         | 20 +++++++++++++-------
 test/test_linalg.py       | 17 +++++++++--------
 test/test_loopy.py        | 15 ++++++++-------
 7 files changed, 109 insertions(+), 46 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index fab13029c..f7347c118 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -467,4 +467,11 @@ Options
 
 .. autofunction:: set_options
 
+Controlling caching
+-------------------
+
+.. autofunction:: set_caching_enabled
+
+.. autoclass:: CacheMode
+
 .. vim: tw=75:spell
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 8934aebc6..d16d01597 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1325,4 +1325,38 @@ def register_function_manglers(kernel, manglers):
 # }}}
 
 
+# {{{ cache control
+
+CACHING_ENABLED = True
+
+
+def set_caching_enabled(flag):
+    """Set whether :mod:`loopy` is allowed to use disk caching for its various
+    code generation stages.
+    """
+    global CACHING_ENABLED
+    CACHING_ENABLED = flag
+
+
+class CacheMode(object):
+    """A context manager for setting whether :mod:`loopy` is allowed to use
+    disk caches.
+    """
+
+    def __init__(self, new_flag):
+        self.new_flag = new_flag
+
+    def __enter__(self):
+        global CACHING_ENABLED
+        self.previous_mode = CACHING_ENABLED
+        CACHING_ENABLED = self.new_flag
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global CACHING_ENABLED
+        CACHING_ENABLED = self.previous_mode
+        del self.previous_mode
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 1c9eb386a..870c9611e 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -341,18 +341,25 @@ def generate_code(kernel, device=None):
         raise LoopyError("cannot generate code for a kernel that has not been "
                 "scheduled")
 
-    if device is not None:
-        device_id = device.persistent_unique_id
-    else:
-        device_id = None
-
-    code_gen_cache_key = (kernel, device_id)
-    try:
-        result = code_gen_cache[code_gen_cache_key]
-        logger.info("%s: code generation cache hit" % kernel.name)
-        return result
-    except KeyError:
-        pass
+    # {{{ cache retrieval
+
+    from loopy import CACHING_ENABLED
+
+    if CACHING_ENABLED:
+        if device is not None:
+            device_id = device.persistent_unique_id
+        else:
+            device_id = None
+
+        code_gen_cache_key = (kernel, device_id)
+        try:
+            result = code_gen_cache[code_gen_cache_key]
+            logger.info("%s: code generation cache hit" % kernel.name)
+            return result
+        except KeyError:
+            pass
+
+    # }}}
 
     from loopy.preprocess import infer_unknown_types
     kernel = infer_unknown_types(kernel, expect_completion=True)
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index ab4a0734b..1ca0676a1 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -1076,18 +1076,24 @@ def preprocess_kernel(kernel, device=None):
         raise LoopyError("cannot re-preprocess an already preprocessed "
                 "kernel")
 
-    if device is not None:
-        device_id = device.persistent_unique_id
-    else:
-        device_id = None
+    # {{{ cache retrieval
+
+    from loopy import CACHING_ENABLED
+    if CACHING_ENABLED:
+        if device is not None:
+            device_id = device.persistent_unique_id
+        else:
+            device_id = None
 
-    pp_cache_key = (kernel, device_id)
-    try:
-        result = preprocess_cache[pp_cache_key]
-        logger.info("%s: preprocess cache hit" % kernel.name)
-        return result
-    except KeyError:
-        pass
+        pp_cache_key = (kernel, device_id)
+        try:
+            result = preprocess_cache[pp_cache_key]
+            logger.info("%s: preprocess cache hit" % kernel.name)
+            return result
+        except KeyError:
+            pass
+
+    # }}}
 
     logger.info("%s: preprocess start" % kernel.name)
 
@@ -1135,7 +1141,8 @@ def preprocess_kernel(kernel, device=None):
     kernel = kernel.copy(
             state=kernel_state.PREPROCESSED)
 
-    preprocess_cache[pp_cache_key] = kernel
+    if CACHING_ENABLED:
+        preprocess_cache[pp_cache_key] = kernel
 
     return kernel
 
diff --git a/loopy/schedule.py b/loopy/schedule.py
index fa81ae756..d9253dea3 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -1111,15 +1111,21 @@ schedule_cache = PersistentDict("loopy-schedule-cache-v2-"+VERSION_TEXT,
 
 
 def get_one_scheduled_kernel(kernel):
+    from loopy import CACHING_ENABLED
 
     sched_cache_key = kernel
-    try:
-        result, ambiguous = schedule_cache[sched_cache_key]
+    from_cache = False
 
-        logger.info("%s: schedule cache hit" % kernel.name)
-        from_cache = True
-    except KeyError:
-        from_cache = False
+    if CACHING_ENABLED:
+        try:
+            result, ambiguous = schedule_cache[sched_cache_key]
+
+            logger.info("%s: schedule cache hit" % kernel.name)
+            from_cache = True
+        except KeyError:
+            pass
+
+    if not from_cache:
         ambiguous = False
 
         kernel_count = 0
@@ -1142,7 +1148,7 @@ def get_one_scheduled_kernel(kernel):
                 "schedule found, ignoring", LoopyWarning,
                 stacklevel=2)
 
-    if not from_cache:
+    if CACHING_ENABLED and not from_cache:
         schedule_cache[sched_cache_key] = result, ambiguous
 
     return result
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 47c7600d6..cfc3df7da 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -530,14 +530,15 @@ def test_ilp_race_matmul(ctx_factory):
     knl = lp.split_iname(knl, "k", 2)
     knl = lp.add_prefetch(knl, 'b', ["k_inner"])
 
-    from loopy.diagnostic import WriteRaceConditionWarning
-    from warnings import catch_warnings
-    with catch_warnings(record=True) as warn_list:
-        knl = lp.preprocess_kernel(knl)
-        list(lp.generate_loop_schedules(knl))
-
-        assert any(isinstance(w.message, WriteRaceConditionWarning)
-                for w in warn_list)
+    with lp.CacheMode(False):
+        from loopy.diagnostic import WriteRaceConditionWarning
+        from warnings import catch_warnings
+        with catch_warnings(record=True) as warn_list:
+            knl = lp.preprocess_kernel(knl)
+            list(lp.generate_loop_schedules(knl))
+
+            assert any(isinstance(w.message, WriteRaceConditionWarning)
+                    for w in warn_list)
 
 
 def test_fancy_matrix_mul(ctx_factory):
diff --git a/test/test_loopy.py b/test/test_loopy.py
index a3f07f299..94bd89aca 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -933,13 +933,14 @@ def test_ilp_write_race_detection_global(ctx_factory):
 
     knl = lp.preprocess_kernel(knl, ctx.devices[0])
 
-    from loopy.diagnostic import WriteRaceConditionWarning
-    from warnings import catch_warnings
-    with catch_warnings(record=True) as warn_list:
-        list(lp.generate_loop_schedules(knl))
-
-        assert any(isinstance(w.message, WriteRaceConditionWarning)
-                for w in warn_list)
+    with lp.CacheMode(False):
+        from loopy.diagnostic import WriteRaceConditionWarning
+        from warnings import catch_warnings
+        with catch_warnings(record=True) as warn_list:
+            list(lp.generate_loop_schedules(knl))
+
+            assert any(isinstance(w.message, WriteRaceConditionWarning)
+                    for w in warn_list)
 
 
 def test_ilp_write_race_avoidance_local(ctx_factory):
-- 
GitLab