Skip to content
Snippets Groups Projects
Commit a336b851 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge bodge:src/loopy

parents 4c292b1e b99c1622
No related branches found
No related tags found
No related merge requests found
...@@ -29,6 +29,9 @@ from loopy.symbolic import (get_dependencies, ...@@ -29,6 +29,9 @@ from loopy.symbolic import (get_dependencies,
RuleAwareIdentityMapper, SubstitutionRuleMappingContext, RuleAwareIdentityMapper, SubstitutionRuleMappingContext,
SubstitutionMapper) SubstitutionMapper)
from pymbolic.mapper.substitutor import make_subst_func from pymbolic.mapper.substitutor import make_subst_func
from pytools.persistent_dict import PersistentDict
from loopy.tools import LoopyKeyBuilder
from loopy.version import DATA_MODEL_VERSION
from pymbolic import var from pymbolic import var
...@@ -117,6 +120,11 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper): ...@@ -117,6 +120,11 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
# }}} # }}}
buffer_array_cache = PersistentDict("loopy-buffer-array-cachee"+DATA_MODEL_VERSION,
key_builder=LoopyKeyBuilder())
# Adding an argument? also add something to the cache_key below.
def buffer_array(kernel, var_name, buffer_inames, init_expression=None, def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
store_expression=None, within=None, default_tag="l.auto", store_expression=None, within=None, default_tag="l.auto",
temporary_is_local=None, fetch_bounding_box=False): temporary_is_local=None, fetch_bounding_box=False):
...@@ -173,6 +181,22 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, ...@@ -173,6 +181,22 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
# }}} # }}}
# {{{ caching
from loopy import CACHING_ENABLED
cache_key = (kernel, var_name, tuple(buffer_inames),
init_expression, store_expression, within,
default_tag, temporary_is_local, fetch_bounding_box)
if CACHING_ENABLED:
try:
return buffer_array_cache[cache_key]
except KeyError:
pass
# }}}
var_name_gen = kernel.get_var_name_generator() var_name_gen = kernel.get_var_name_generator()
within_inames = set() within_inames = set()
...@@ -413,6 +437,10 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, ...@@ -413,6 +437,10 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
from loopy import tag_inames from loopy import tag_inames
kernel = tag_inames(kernel, new_iname_to_tag) kernel = tag_inames(kernel, new_iname_to_tag)
if 0 and CACHING_ENABLED:
from loopy.preprocess import prepare_for_caching
buffer_array_cache[cache_key] = prepare_for_caching(kernel)
return kernel return kernel
# vim: foldmethod=marker # vim: foldmethod=marker
...@@ -94,11 +94,21 @@ class MatchExpressionBase(object): ...@@ -94,11 +94,21 @@ class MatchExpressionBase(object):
def __call__(self, kernel, matchable): def __call__(self, kernel, matchable):
raise NotImplementedError raise NotImplementedError
def __ne__(self, other):
return not self.__eq__(other)
class AllMatchExpression(MatchExpressionBase): class AllMatchExpression(MatchExpressionBase):
def __call__(self, kernel, matchable): def __call__(self, kernel, matchable):
return True return True
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "all_match_expr")
def __eq__(self, other):
return (type(self) == type(other))
class AndMatchExpression(MatchExpressionBase): class AndMatchExpression(MatchExpressionBase):
def __init__(self, children): def __init__(self, children):
...@@ -110,6 +120,14 @@ class AndMatchExpression(MatchExpressionBase): ...@@ -110,6 +120,14 @@ class AndMatchExpression(MatchExpressionBase):
def __str__(self): def __str__(self):
return "(%s)" % (" and ".join(str(ch) for ch in self.children)) return "(%s)" % (" and ".join(str(ch) for ch in self.children))
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "and_match_expr")
key_builder.rec(key_hash, self.children)
def __eq__(self, other):
return (type(self) == type(other)
and self.children == other.children)
class OrMatchExpression(MatchExpressionBase): class OrMatchExpression(MatchExpressionBase):
def __init__(self, children): def __init__(self, children):
...@@ -121,6 +139,14 @@ class OrMatchExpression(MatchExpressionBase): ...@@ -121,6 +139,14 @@ class OrMatchExpression(MatchExpressionBase):
def __str__(self): def __str__(self):
return "(%s)" % (" or ".join(str(ch) for ch in self.children)) return "(%s)" % (" or ".join(str(ch) for ch in self.children))
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "or_match_expr")
key_builder.rec(key_hash, self.children)
def __eq__(self, other):
return (type(self) == type(other)
and self.children == other.children)
class NotMatchExpression(MatchExpressionBase): class NotMatchExpression(MatchExpressionBase):
def __init__(self, child): def __init__(self, child):
...@@ -132,6 +158,14 @@ class NotMatchExpression(MatchExpressionBase): ...@@ -132,6 +158,14 @@ class NotMatchExpression(MatchExpressionBase):
def __str__(self): def __str__(self):
return "(not %s)" % str(self.child) return "(not %s)" % str(self.child)
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "not_match_expr")
key_builder.rec(key_hash, self.child)
def __eq__(self, other):
return (type(self) == type(other)
and self.child == other.child)
class GlobMatchExpressionBase(MatchExpressionBase): class GlobMatchExpressionBase(MatchExpressionBase):
def __init__(self, glob): def __init__(self, glob):
...@@ -146,6 +180,14 @@ class GlobMatchExpressionBase(MatchExpressionBase): ...@@ -146,6 +180,14 @@ class GlobMatchExpressionBase(MatchExpressionBase):
descr = descr[:descr.find("Match")] descr = descr[:descr.find("Match")]
return descr.lower() + ":" + self.glob return descr.lower() + ":" + self.glob
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, type(self).__name__)
key_builder.rec(key_hash, self.glob)
def __eq__(self, other):
return (type(self) == type(other)
and self.glob == other.glob)
class IdMatchExpression(GlobMatchExpressionBase): class IdMatchExpression(GlobMatchExpressionBase):
def __call__(self, kernel, matchable): def __call__(self, kernel, matchable):
...@@ -284,18 +326,31 @@ def parse_match(expr_str): ...@@ -284,18 +326,31 @@ def parse_match(expr_str):
# {{{ stack match objects # {{{ stack match objects
class StackMatchComponent(object): class StackMatchComponent(object):
pass def __ne__(self, other):
return not self.__eq__(other)
class StackAllMatchComponent(StackMatchComponent): class StackAllMatchComponent(StackMatchComponent):
def __call__(self, kernel, stack): def __call__(self, kernel, stack):
return True return True
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "all_match")
def __eq__(self, other):
return (type(self) == type(other))
class StackBottomMatchComponent(StackMatchComponent): class StackBottomMatchComponent(StackMatchComponent):
def __call__(self, kernel, stack): def __call__(self, kernel, stack):
return not stack return not stack
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "bottom_match")
def __eq__(self, other):
return (type(self) == type(other))
class StackItemMatchComponent(StackMatchComponent): class StackItemMatchComponent(StackMatchComponent):
def __init__(self, match_expr, inner_match): def __init__(self, match_expr, inner_match):
...@@ -312,6 +367,16 @@ class StackItemMatchComponent(StackMatchComponent): ...@@ -312,6 +367,16 @@ class StackItemMatchComponent(StackMatchComponent):
return self.inner_match(kernel, stack[1:]) return self.inner_match(kernel, stack[1:])
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, "item_match")
key_builder.rec(key_hash, self.match_expr)
key_builder.rec(key_hash, self.inner_match)
def __eq__(self, other):
return (type(self) == type(other)
and self.match_expr == other.match_expr
and self.inner_match == other.inner_match)
class StackWildcardMatchComponent(StackMatchComponent): class StackWildcardMatchComponent(StackMatchComponent):
def __init__(self, inner_match): def __init__(self, inner_match):
...@@ -348,6 +413,18 @@ class StackMatch(object): ...@@ -348,6 +413,18 @@ class StackMatch(object):
def __init__(self, root_component): def __init__(self, root_component):
self.root_component = root_component self.root_component = root_component
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, self.root_component)
def __eq__(self, other):
return (
type(self) == type(other)
and
self.root_component == other.root_component)
def __ne__(self, other):
return not self.__eq__(other)
def __call__(self, kernel, insn, rule_stack): def __call__(self, kernel, insn, rule_stack):
""" """
:arg rule_stack: a tuple of (name, tags) rule invocation, outermost first :arg rule_stack: a tuple of (name, tags) rule invocation, outermost first
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment