diff --git a/loopy/buffer.py b/loopy/buffer.py index fdc3774b29f64ba5ae8c465076f48b805836d40b..2c59539500a41f655d0e83e7e0520d280aed7a7e 100644 --- a/loopy/buffer.py +++ b/loopy/buffer.py @@ -29,6 +29,9 @@ from loopy.symbolic import (get_dependencies, RuleAwareIdentityMapper, SubstitutionRuleMappingContext, SubstitutionMapper) from pymbolic.mapper.substitutor import make_subst_func +from pytools.persistent_dict import PersistentDict +from loopy.tools import LoopyKeyBuilder +from loopy.version import DATA_MODEL_VERSION from pymbolic import var @@ -117,6 +120,11 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper): # }}} +buffer_array_cache = PersistentDict("loopy-buffer-array-cachee"+DATA_MODEL_VERSION, + key_builder=LoopyKeyBuilder()) + + +# Adding an argument? also add something to the cache_key below. def buffer_array(kernel, var_name, buffer_inames, init_expression=None, store_expression=None, within=None, default_tag="l.auto", temporary_is_local=None, fetch_bounding_box=False): @@ -173,6 +181,22 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, # }}} + # {{{ caching + + from loopy import CACHING_ENABLED + + cache_key = (kernel, var_name, tuple(buffer_inames), + init_expression, store_expression, within, + default_tag, temporary_is_local, fetch_bounding_box) + + if CACHING_ENABLED: + try: + return buffer_array_cache[cache_key] + except KeyError: + pass + + # }}} + var_name_gen = kernel.get_var_name_generator() within_inames = set() @@ -413,6 +437,10 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, from loopy import tag_inames kernel = tag_inames(kernel, new_iname_to_tag) + if 0 and CACHING_ENABLED: + from loopy.preprocess import prepare_for_caching + buffer_array_cache[cache_key] = prepare_for_caching(kernel) + return kernel # vim: foldmethod=marker diff --git a/loopy/context_matching.py b/loopy/context_matching.py index 61203ece2c38ae7beb385bd8b4758c3ce5eeeea8..a88e207002220a1be840114d71948869f566863d 100644 --- a/loopy/context_matching.py +++ b/loopy/context_matching.py @@ -94,11 +94,21 @@ class MatchExpressionBase(object): def __call__(self, kernel, matchable): raise NotImplementedError + def __ne__(self, other): + return not self.__eq__(other) + + class AllMatchExpression(MatchExpressionBase): def __call__(self, kernel, matchable): return True + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, "all_match_expr") + + def __eq__(self, other): + return (type(self) == type(other)) + class AndMatchExpression(MatchExpressionBase): def __init__(self, children): @@ -110,6 +120,14 @@ class AndMatchExpression(MatchExpressionBase): def __str__(self): return "(%s)" % (" and ".join(str(ch) for ch in self.children)) + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, "and_match_expr") + key_builder.rec(key_hash, self.children) + + def __eq__(self, other): + return (type(self) == type(other) + and self.children == other.children) + class OrMatchExpression(MatchExpressionBase): def __init__(self, children): @@ -121,6 +139,14 @@ class OrMatchExpression(MatchExpressionBase): def __str__(self): return "(%s)" % (" or ".join(str(ch) for ch in self.children)) + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, "or_match_expr") + key_builder.rec(key_hash, self.children) + + def __eq__(self, other): + return (type(self) == type(other) + and self.children == other.children) + class NotMatchExpression(MatchExpressionBase): def __init__(self, child): @@ -132,6 +158,14 @@ class NotMatchExpression(MatchExpressionBase): def __str__(self): return "(not %s)" % str(self.child) + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, "not_match_expr") + key_builder.rec(key_hash, self.child) + + def __eq__(self, other): + return (type(self) == type(other) + and self.child == other.child) + class GlobMatchExpressionBase(MatchExpressionBase): def __init__(self, glob): @@ -146,6 +180,14 @@ class GlobMatchExpressionBase(MatchExpressionBase): descr = descr[:descr.find("Match")] return descr.lower() + ":" + self.glob + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, type(self).__name__) + key_builder.rec(key_hash, self.glob) + + def __eq__(self, other): + return (type(self) == type(other) + and self.glob == other.glob) + class IdMatchExpression(GlobMatchExpressionBase): def __call__(self, kernel, matchable): @@ -284,18 +326,31 @@ def parse_match(expr_str): # {{{ stack match objects class StackMatchComponent(object): - pass + def __ne__(self, other): + return not self.__eq__(other) class StackAllMatchComponent(StackMatchComponent): def __call__(self, kernel, stack): return True + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, "all_match") + + def __eq__(self, other): + return (type(self) == type(other)) + class StackBottomMatchComponent(StackMatchComponent): def __call__(self, kernel, stack): return not stack + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, "bottom_match") + + def __eq__(self, other): + return (type(self) == type(other)) + class StackItemMatchComponent(StackMatchComponent): def __init__(self, match_expr, inner_match): @@ -312,6 +367,16 @@ class StackItemMatchComponent(StackMatchComponent): return self.inner_match(kernel, stack[1:]) + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, "item_match") + key_builder.rec(key_hash, self.match_expr) + key_builder.rec(key_hash, self.inner_match) + + def __eq__(self, other): + return (type(self) == type(other) + and self.match_expr == other.match_expr + and self.inner_match == other.inner_match) + class StackWildcardMatchComponent(StackMatchComponent): def __init__(self, inner_match): @@ -348,6 +413,18 @@ class StackMatch(object): def __init__(self, root_component): self.root_component = root_component + def update_persistent_hash(self, key_hash, key_builder): + key_builder.rec(key_hash, self.root_component) + + def __eq__(self, other): + return ( + type(self) == type(other) + and + self.root_component == other.root_component) + + def __ne__(self, other): + return not self.__eq__(other) + def __call__(self, kernel, insn, rule_stack): """ :arg rule_stack: a tuple of (name, tags) rule invocation, outermost first