From 11d87fcc27309e9af92ca33bf221f96aff4848a5 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 21 Aug 2015 11:05:47 -0500
Subject: [PATCH] buffer_array caching

---
 loopy/buffer.py      | 14 ++++++++++----
 loopy/tools.py       | 15 +++++++++++++++
 test/test_fortran.py |  2 ++
 3 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/loopy/buffer.py b/loopy/buffer.py
index 2c5953950..1e6a137b5 100644
--- a/loopy/buffer.py
+++ b/loopy/buffer.py
@@ -30,11 +30,14 @@ from loopy.symbolic import (get_dependencies,
         SubstitutionMapper)
 from pymbolic.mapper.substitutor import make_subst_func
 from pytools.persistent_dict import PersistentDict
-from loopy.tools import LoopyKeyBuilder
+from loopy.tools import LoopyKeyBuilder, PymbolicExpressionHashWrapper
 from loopy.version import DATA_MODEL_VERSION
 
 from pymbolic import var
 
+import logging
+logger = logging.getLogger(__name__)
+
 
 # {{{ replace array access
 
@@ -186,12 +189,15 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
     from loopy import CACHING_ENABLED
 
     cache_key = (kernel, var_name, tuple(buffer_inames),
-            init_expression, store_expression, within,
+            PymbolicExpressionHashWrapper(init_expression),
+            PymbolicExpressionHashWrapper(store_expression), within,
             default_tag, temporary_is_local, fetch_bounding_box)
 
     if CACHING_ENABLED:
         try:
-            return buffer_array_cache[cache_key]
+            result = buffer_array_cache[cache_key]
+            logger.info("%s: buffer_array cache hit" % kernel.name)
+            return result
         except KeyError:
             pass
 
@@ -437,7 +443,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
     from loopy import tag_inames
     kernel = tag_inames(kernel, new_iname_to_tag)
 
-    if 0 and CACHING_ENABLED:
+    if CACHING_ENABLED:
         from loopy.preprocess import prepare_for_caching
         buffer_array_cache[cache_key] = prepare_for_caching(kernel)
 
diff --git a/loopy/tools.py b/loopy/tools.py
index 861d15568..55b177bda 100644
--- a/loopy/tools.py
+++ b/loopy/tools.py
@@ -96,6 +96,21 @@ class LoopyKeyBuilder(KeyBuilderBase):
         else:
             PersistentHashWalkMapper(key_hash)(key)
 
+
+class PymbolicExpressionHashWrapper(object):
+    def __init__(self, expression):
+        self.expression = expression
+
+    def __eq__(self, other):
+        return (type(self) == type(other)
+                and self.expression == other.expression)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def update_persistent_hash(self, key_hash, key_builder):
+        key_builder.update_for_pymbolic_expression(key_hash, self.expression)
+
 # }}}
 
 
diff --git a/test/test_fortran.py b/test/test_fortran.py
index 212233ebc..a5b1b830b 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -275,6 +275,8 @@ def test_tagged(ctx_factory):
     "i_inner,j_inner",
     ])
 def test_matmul(ctx_factory, buffer_inames):
+    logging.basicConfig(level=logging.INFO)
+
     fortran_src = """
         subroutine dgemm(m,n,l,a,b,c)
           implicit none
-- 
GitLab