From 11d87fcc27309e9af92ca33bf221f96aff4848a5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Fri, 21 Aug 2015 11:05:47 -0500 Subject: [PATCH] buffer_array caching --- loopy/buffer.py | 14 ++++++++++---- loopy/tools.py | 15 +++++++++++++++ test/test_fortran.py | 2 ++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/loopy/buffer.py b/loopy/buffer.py index 2c5953950..1e6a137b5 100644 --- a/loopy/buffer.py +++ b/loopy/buffer.py @@ -30,11 +30,14 @@ from loopy.symbolic import (get_dependencies, SubstitutionMapper) from pymbolic.mapper.substitutor import make_subst_func from pytools.persistent_dict import PersistentDict -from loopy.tools import LoopyKeyBuilder +from loopy.tools import LoopyKeyBuilder, PymbolicExpressionHashWrapper from loopy.version import DATA_MODEL_VERSION from pymbolic import var +import logging +logger = logging.getLogger(__name__) + # {{{ replace array access @@ -186,12 +189,15 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, from loopy import CACHING_ENABLED cache_key = (kernel, var_name, tuple(buffer_inames), - init_expression, store_expression, within, + PymbolicExpressionHashWrapper(init_expression), + PymbolicExpressionHashWrapper(store_expression), within, default_tag, temporary_is_local, fetch_bounding_box) if CACHING_ENABLED: try: - return buffer_array_cache[cache_key] + result = buffer_array_cache[cache_key] + logger.info("%s: buffer_array cache hit" % kernel.name) + return result except KeyError: pass @@ -437,7 +443,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, from loopy import tag_inames kernel = tag_inames(kernel, new_iname_to_tag) - if 0 and CACHING_ENABLED: + if CACHING_ENABLED: from loopy.preprocess import prepare_for_caching buffer_array_cache[cache_key] = prepare_for_caching(kernel) diff --git a/loopy/tools.py b/loopy/tools.py index 861d15568..55b177bda 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -96,6 +96,21 @@ class LoopyKeyBuilder(KeyBuilderBase): else: PersistentHashWalkMapper(key_hash)(key) + +class PymbolicExpressionHashWrapper(object): + def __init__(self, expression): + self.expression = expression + + def __eq__(self, other): + return (type(self) == type(other) + and self.expression == other.expression) + + def __ne__(self, other): + return not self.__eq__(other) + + def update_persistent_hash(self, key_hash, key_builder): + key_builder.update_for_pymbolic_expression(key_hash, self.expression) + # }}} diff --git a/test/test_fortran.py b/test/test_fortran.py index 212233ebc..a5b1b830b 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -275,6 +275,8 @@ def test_tagged(ctx_factory): "i_inner,j_inner", ]) def test_matmul(ctx_factory, buffer_inames): + logging.basicConfig(level=logging.INFO) + fortran_src = """ subroutine dgemm(m,n,l,a,b,c) implicit none -- GitLab