diff --git a/loopy/codegen/control.py b/loopy/codegen/control.py
index 9ed8da938ba9ded5204d4fb9c10a9799e37cc3d6..8c51ee10ca12d80fd5ea27b541a88809a3e19248 100644
--- a/loopy/codegen/control.py
+++ b/loopy/codegen/control.py
@@ -156,6 +156,20 @@ def build_loop_nest(kernel, sched_index, codegen_state):
 
     # {{{ pass 3: greedily group schedule items that share admissible inames
 
+    from pytools import memoize_method
+
+    class BoundsCheckCache:
+        """Memoizes bounds checks per iname set against its own impl_domain."""
+        def __init__(self, domain, impl_domain):
+            self.domain = domain
+            self.impl_domain = impl_domain
+
+        @memoize_method
+        def __call__(self, check_inames):
+            from loopy.codegen.bounds import generate_bounds_checks
+            return generate_bounds_checks(
+                    self.domain, check_inames, self.impl_domain)
+
     def build_insn_group(sched_indices_and_cond_inames, codegen_state, done_group_lengths=set()):
         # done_group_lengths serves to prevent infinite recursion by imposing a
         # bigger and bigger minimum size on the group of shared inames found.
@@ -170,6 +184,8 @@ def build_loop_nest(kernel, sched_index, codegen_state):
         # Keep growing schedule item group as long as group fulfills minimum
         # size requirement.
 
+        bounds_check_cache = BoundsCheckCache(kernel.domain, codegen_state.implemented_domain)
+
         current_iname_set = cond_inames
 
         found_hoists = []
@@ -193,14 +209,13 @@ def build_loop_nest(kernel, sched_index, codegen_state):
 
             # }}}
 
-            from loopy.codegen.bounds import generate_bounds_checks
             only_unshared_inames = remove_inames_for_shared_hw_axes(kernel,
                     current_iname_set & used_inames)
 
-            bounds_checks = generate_bounds_checks(kernel.domain,
-                    remove_inames_for_shared_hw_axes(kernel,
-                        only_unshared_inames),
-                    codegen_state.implemented_domain)
+            # frozenset: presumably so the memoized cache gets a hashable key.
+            bounds_checks = bounds_check_cache(
+                    frozenset(remove_inames_for_shared_hw_axes(kernel,
+                        only_unshared_inames)))
 
             if bounds_checks or candidate_group_length == 1:
                 # length-1 must always be an option to reach the recursion base case below
diff --git a/loopy/kernel.py b/loopy/kernel.py
index eb7a340e49432e122c98c8504805ec38143411f2..5012f154fa218df920436b3ea7f8709c5e2a00c1 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -1285,7 +1285,16 @@ class SetOperationCacheManager:
         # mapping: set hash -> [(set, op, args, result)]
         self.cache = {}
 
-    def op(self, set, op, args):
+    def op(self, set, op_name, op, args):
+        """Return the result of ``op(*args)`` for *set*, recording it.
+
+        New results are appended to the ``hash(set)`` bucket of
+        ``self.cache`` as ``(set, op_name, args, result)``.
+
+        NOTE(review): the bucket lookup lies between this hunk and the next
+        and is not shown here -- confirm it matches on *op_name*; *op* is
+        now a callable while stored entries hold the name.
+        """
         hashval = hash(set)
         bucket = self.cache.setdefault(hashval, [])
 
@@ -1294,15 +1303,17 @@ class SetOperationCacheManager:
                 return result
 
         #print op, set.get_dim_name(dim_type.set, args[0])
-        result = getattr(set, op)(*args)
-        bucket.append((set, op, args, result))
+        result = op(*args)
+        bucket.append((set, op_name, args, result))
         return result
 
     def dim_min(self, set, *args):
-        return self.op(set, "dim_min", args)
+        """Return ``set.dim_min(*args)`` by way of :meth:`op`."""
+        return self.op(set, "dim_min", set.dim_min, args)
 
     def dim_max(self, set, *args):
-        return self.op(set, "dim_max", args)
+        """Return ``set.dim_max(*args)`` by way of :meth:`op`."""
+        return self.op(set, "dim_max", set.dim_max, args)