diff --git a/loopy/buffer.py b/loopy/buffer.py index 2c59539500a41f655d0e83e7e0520d280aed7a7e..1e6a137b551645a25145ddaaeb8eea40eea554af 100644 --- a/loopy/buffer.py +++ b/loopy/buffer.py @@ -30,11 +30,14 @@ from loopy.symbolic import (get_dependencies, SubstitutionMapper) from pymbolic.mapper.substitutor import make_subst_func from pytools.persistent_dict import PersistentDict -from loopy.tools import LoopyKeyBuilder +from loopy.tools import LoopyKeyBuilder, PymbolicExpressionHashWrapper from loopy.version import DATA_MODEL_VERSION from pymbolic import var +import logging +logger = logging.getLogger(__name__) + # {{{ replace array access @@ -186,12 +189,15 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, from loopy import CACHING_ENABLED cache_key = (kernel, var_name, tuple(buffer_inames), - init_expression, store_expression, within, + PymbolicExpressionHashWrapper(init_expression), + PymbolicExpressionHashWrapper(store_expression), within, default_tag, temporary_is_local, fetch_bounding_box) if CACHING_ENABLED: try: - return buffer_array_cache[cache_key] + result = buffer_array_cache[cache_key] + logger.info("%s: buffer_array cache hit" % kernel.name) + return result except KeyError: pass @@ -437,7 +443,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, from loopy import tag_inames kernel = tag_inames(kernel, new_iname_to_tag) - if 0 and CACHING_ENABLED: + if CACHING_ENABLED: from loopy.preprocess import prepare_for_caching buffer_array_cache[cache_key] = prepare_for_caching(kernel) diff --git a/loopy/tools.py b/loopy/tools.py index 861d155686e1ffde5b480536cc5e71ad1d2841de..55b177bda4e6be03a985286fd4faf6322e257824 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -96,6 +96,21 @@ class LoopyKeyBuilder(KeyBuilderBase): else: PersistentHashWalkMapper(key_hash)(key) + +class PymbolicExpressionHashWrapper(object): + def __init__(self, expression): + self.expression = expression + + def __eq__(self, other): + return (type(self) == type(other) + and self.expression == other.expression) + + def __ne__(self, other): + return not self.__eq__(other) + + def update_persistent_hash(self, key_hash, key_builder): + key_builder.update_for_pymbolic_expression(key_hash, self.expression) + # }}} diff --git a/test/test_fortran.py b/test/test_fortran.py index 212233ebcd895a93b5e89674323b339d98c08e21..a5b1b830bc8834637d5f4c609fff8232ef7449e6 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -275,6 +275,8 @@ def test_tagged(ctx_factory): "i_inner,j_inner", ]) def test_matmul(ctx_factory, buffer_inames): + logging.basicConfig(level=logging.INFO) + fortran_src = """ subroutine dgemm(m,n,l,a,b,c) implicit none