diff --git a/doc/reference.rst b/doc/reference.rst index fab13029c167891df4e19894ebe4db7c1aa5de6d..f7347c118fda00fa7a6ba9e9f36a26bc83c3a077 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -467,4 +467,11 @@ Options .. autofunction:: set_options +Controlling caching +------------------- + +.. autofunction:: set_caching_enabled + +.. autoclass:: CacheMode + .. vim: tw=75:spell diff --git a/loopy/__init__.py b/loopy/__init__.py index 8934aebc633d53b80df7d013306e3e73552c19d5..d16d0159706d6ad09b817b327489ec7b153868a6 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -1325,4 +1325,38 @@ def register_function_manglers(kernel, manglers): # }}} +# {{{ cache control + +CACHING_ENABLED = True + + +def set_caching_enabled(flag): + """Set whether :mod:`loopy` is allowed to use disk caching for its various + code generation stages. + """ + global CACHING_ENABLED + CACHING_ENABLED = flag + + +class CacheMode(object): + """A context manager for setting whether :mod:`loopy` is allowed to use + disk caches. + """ + + def __init__(self, new_flag): + self.new_flag = new_flag + + def __enter__(self): + global CACHING_ENABLED + self.previous_mode = CACHING_ENABLED + CACHING_ENABLED = self.new_flag + + def __exit__(self, exc_type, exc_val, exc_tb): + global CACHING_ENABLED + CACHING_ENABLED = self.previous_mode + del self.previous_mode + +# }}} + + # vim: foldmethod=marker diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 1c9eb386aa2f98c6793e624ef5393a7c07b7211c..870c9611e0c161034652d898eddcb29933a56236 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -341,18 +341,25 @@ def generate_code(kernel, device=None): raise LoopyError("cannot generate code for a kernel that has not been " "scheduled") - if device is not None: - device_id = device.persistent_unique_id - else: - device_id = None - - code_gen_cache_key = (kernel, device_id) - try: - result = code_gen_cache[code_gen_cache_key] - logger.info("%s: code generation cache hit" % kernel.name) - return result - except KeyError: - pass + # {{{ cache retrieval + + from loopy import CACHING_ENABLED + + if CACHING_ENABLED: + if device is not None: + device_id = device.persistent_unique_id + else: + device_id = None + + code_gen_cache_key = (kernel, device_id) + try: + result = code_gen_cache[code_gen_cache_key] + logger.info("%s: code generation cache hit" % kernel.name) + return result + except KeyError: + pass + + # }}} from loopy.preprocess import infer_unknown_types kernel = infer_unknown_types(kernel, expect_completion=True) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index ab4a0734b37da4cb4b15532fa2ffbf4310b93bf9..1ca0676a1686a4ab551a8f343c558944b476376c 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -1076,18 +1076,24 @@ def preprocess_kernel(kernel, device=None): raise LoopyError("cannot re-preprocess an already preprocessed " "kernel") - if device is not None: - device_id = device.persistent_unique_id - else: - device_id = None + # {{{ cache retrieval + + from loopy import CACHING_ENABLED + if CACHING_ENABLED: + if device is not None: + device_id = device.persistent_unique_id + else: + device_id = None - pp_cache_key = (kernel, device_id) - try: - result = preprocess_cache[pp_cache_key] - logger.info("%s: preprocess cache hit" % kernel.name) - return result - except KeyError: - pass + pp_cache_key = (kernel, device_id) + try: + result = preprocess_cache[pp_cache_key] + logger.info("%s: preprocess cache hit" % kernel.name) + return result + except KeyError: + pass + + # }}} logger.info("%s: preprocess start" % kernel.name) @@ -1135,7 +1141,8 @@ def preprocess_kernel(kernel, device=None): kernel = kernel.copy( state=kernel_state.PREPROCESSED) - preprocess_cache[pp_cache_key] = kernel + if CACHING_ENABLED: + preprocess_cache[pp_cache_key] = kernel return kernel diff --git a/loopy/schedule.py b/loopy/schedule.py index fa81ae756d62e2f5f35b943e8bdf0bb13e6dcf26..d9253dea3b10edaccf2e880f74e1cbe927c97afd 100644 --- a/loopy/schedule.py +++ b/loopy/schedule.py @@ -1111,15 +1111,21 @@ schedule_cache = PersistentDict("loopy-schedule-cache-v2-"+VERSION_TEXT, def get_one_scheduled_kernel(kernel): + from loopy import CACHING_ENABLED sched_cache_key = kernel - try: - result, ambiguous = schedule_cache[sched_cache_key] + from_cache = False - logger.info("%s: schedule cache hit" % kernel.name) - from_cache = True - except KeyError: - from_cache = False + if CACHING_ENABLED: + try: + result, ambiguous = schedule_cache[sched_cache_key] + + logger.info("%s: schedule cache hit" % kernel.name) + from_cache = True + except KeyError: + pass + + if not from_cache: ambiguous = False kernel_count = 0 @@ -1142,7 +1148,7 @@ def get_one_scheduled_kernel(kernel): "schedule found, ignoring", LoopyWarning, stacklevel=2) - if not from_cache: + if CACHING_ENABLED and not from_cache: schedule_cache[sched_cache_key] = result, ambiguous return result diff --git a/test/test_linalg.py b/test/test_linalg.py index 47c7600d6dc74c4468f9ca2010801589be777789..cfc3df7da4ef30a5ffb511aa8e966e1b19a690f7 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -530,14 +530,15 @@ def test_ilp_race_matmul(ctx_factory): knl = lp.split_iname(knl, "k", 2) knl = lp.add_prefetch(knl, 'b', ["k_inner"]) - from loopy.diagnostic import WriteRaceConditionWarning - from warnings import catch_warnings - with catch_warnings(record=True) as warn_list: - knl = lp.preprocess_kernel(knl) - list(lp.generate_loop_schedules(knl)) - - assert any(isinstance(w.message, WriteRaceConditionWarning) - for w in warn_list) + with lp.CacheMode(False): + from loopy.diagnostic import WriteRaceConditionWarning + from warnings import catch_warnings + with catch_warnings(record=True) as warn_list: + knl = lp.preprocess_kernel(knl) + list(lp.generate_loop_schedules(knl)) + + assert any(isinstance(w.message, WriteRaceConditionWarning) + for w in warn_list) def test_fancy_matrix_mul(ctx_factory): diff --git a/test/test_loopy.py b/test/test_loopy.py index a3f07f299bff6c9b3121e963dd1a612bd6e89725..94bd89acac7cc9e8b4f31fcebe87b3cab393b656 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -933,13 +933,14 @@ def test_ilp_write_race_detection_global(ctx_factory): knl = lp.preprocess_kernel(knl, ctx.devices[0]) - from loopy.diagnostic import WriteRaceConditionWarning - from warnings import catch_warnings - with catch_warnings(record=True) as warn_list: - list(lp.generate_loop_schedules(knl)) - - assert any(isinstance(w.message, WriteRaceConditionWarning) - for w in warn_list) + with lp.CacheMode(False): + from loopy.diagnostic import WriteRaceConditionWarning + from warnings import catch_warnings + with catch_warnings(record=True) as warn_list: + list(lp.generate_loop_schedules(knl)) + + assert any(isinstance(w.message, WriteRaceConditionWarning) + for w in warn_list) def test_ilp_write_race_avoidance_local(ctx_factory):