Skip to content
Snippets Groups Projects
Commit 11d87fcc authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

buffer_array caching

parent a336b851
No related branches found
No related tags found
No related merge requests found
...@@ -30,11 +30,14 @@ from loopy.symbolic import (get_dependencies, ...@@ -30,11 +30,14 @@ from loopy.symbolic import (get_dependencies,
SubstitutionMapper) SubstitutionMapper)
from pymbolic.mapper.substitutor import make_subst_func from pymbolic.mapper.substitutor import make_subst_func
from pytools.persistent_dict import PersistentDict from pytools.persistent_dict import PersistentDict
from loopy.tools import LoopyKeyBuilder from loopy.tools import LoopyKeyBuilder, PymbolicExpressionHashWrapper
from loopy.version import DATA_MODEL_VERSION from loopy.version import DATA_MODEL_VERSION
from pymbolic import var from pymbolic import var
import logging
logger = logging.getLogger(__name__)
# {{{ replace array access # {{{ replace array access
...@@ -186,12 +189,15 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, ...@@ -186,12 +189,15 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
from loopy import CACHING_ENABLED from loopy import CACHING_ENABLED
cache_key = (kernel, var_name, tuple(buffer_inames), cache_key = (kernel, var_name, tuple(buffer_inames),
init_expression, store_expression, within, PymbolicExpressionHashWrapper(init_expression),
PymbolicExpressionHashWrapper(store_expression), within,
default_tag, temporary_is_local, fetch_bounding_box) default_tag, temporary_is_local, fetch_bounding_box)
if CACHING_ENABLED: if CACHING_ENABLED:
try: try:
return buffer_array_cache[cache_key] result = buffer_array_cache[cache_key]
logger.info("%s: buffer_array cache hit" % kernel.name)
return result
except KeyError: except KeyError:
pass pass
...@@ -437,7 +443,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, ...@@ -437,7 +443,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
from loopy import tag_inames from loopy import tag_inames
kernel = tag_inames(kernel, new_iname_to_tag) kernel = tag_inames(kernel, new_iname_to_tag)
if 0 and CACHING_ENABLED: if CACHING_ENABLED:
from loopy.preprocess import prepare_for_caching from loopy.preprocess import prepare_for_caching
buffer_array_cache[cache_key] = prepare_for_caching(kernel) buffer_array_cache[cache_key] = prepare_for_caching(kernel)
......
...@@ -96,6 +96,21 @@ class LoopyKeyBuilder(KeyBuilderBase): ...@@ -96,6 +96,21 @@ class LoopyKeyBuilder(KeyBuilderBase):
else: else:
PersistentHashWalkMapper(key_hash)(key) PersistentHashWalkMapper(key_hash)(key)
class PymbolicExpressionHashWrapper(object):
def __init__(self, expression):
self.expression = expression
def __eq__(self, other):
return (type(self) == type(other)
and self.expression == other.expression)
def __ne__(self, other):
return not self.__eq__(other)
def update_persistent_hash(self, key_hash, key_builder):
key_builder.update_for_pymbolic_expression(key_hash, self.expression)
# }}} # }}}
......
...@@ -275,6 +275,8 @@ def test_tagged(ctx_factory): ...@@ -275,6 +275,8 @@ def test_tagged(ctx_factory):
"i_inner,j_inner", "i_inner,j_inner",
]) ])
def test_matmul(ctx_factory, buffer_inames): def test_matmul(ctx_factory, buffer_inames):
logging.basicConfig(level=logging.INFO)
fortran_src = """ fortran_src = """
subroutine dgemm(m,n,l,a,b,c) subroutine dgemm(m,n,l,a,b,c)
implicit none implicit none
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment