def group_by(l, key, merge):
    """Collapse runs of adjacent items in *l* that have equal keys.

    Only *consecutive* items are combined; items with equal keys that are
    separated by an item with a different key end up in separate entries.

    :arg l: an iterable of items. Any iterable is accepted (lists, tuples,
        iterators, generators); it is traversed exactly once and is never
        modified.
    :arg key: a one-argument callable. Two neighboring items are combined
        if ``key`` returns equal values for the accumulated entry and the
        next item.
    :arg merge: a two-argument callable ``merge(accumulated, item)``
        returning the entry that replaces *accumulated*. The key of the
        returned entry is what the following item is compared against.
    :returns: a new :class:`list` of (possibly merged) entries, in the order
        of first appearance. An empty input yields a new empty list.
    """

    items = iter(l)

    try:
        previous = next(items)
    except StopIteration:
        # Always hand back a fresh list so that callers may mutate the
        # result without touching their input.
        return []

    result = []

    for item in items:
        if key(previous) == key(item):
            # Same run: fold the item into the entry being accumulated.
            previous = merge(previous, item)
        else:
            # Run ended: emit the accumulated entry and start a new one.
            result.append(previous)
            previous = item

    # The last run is still pending once the input is exhausted.
    result.append(previous)
    return result
attribute:: used_inames_within """ + from loopy.schedule import find_used_inames_within sched_index_info_entries = [ ScheduleIndexInfo( - schedule_index=i, + schedule_indices=[i], admissible_cond_inames=( get_admissible_conditional_inames_for(kernel, i)), - required_predicates=get_required_predicates(kernel, i) + required_predicates=get_required_predicates(kernel, i), + used_inames_within=find_used_inames_within(kernel, i) ) - for i in my_sched_indices] + for i in my_sched_indices + ] + + sched_index_info_entries = group_by( + sched_index_info_entries, + key=lambda sii: ( + sii.admissible_cond_inames, + sii.required_predicates, + sii.used_inames_within), + merge=lambda sii1, sii2: sii1.copy( + schedule_indices=( + sii1.schedule_indices + + + sii2.schedule_indices))) # }}} @@ -236,10 +243,10 @@ def build_loop_nest(kernel, sched_index, codegen_state): def build_insn_group(sched_index_info_entries, codegen_state, done_group_lengths=set()): """ - :arg done_group_lengths: A set of group lengths (integers) that grows from - empty to include 1 and upwards with every recursive call. - It serves to prevent infinite recursion by preventing recursive - calls from doing anything about groups that are too small. + :arg done_group_lengths: A set of group lengths (integers) that grows + from empty to include the longest found group and downwards with every + recursive call. It serves to prevent infinite recursion by preventing + recursive calls from doing anything about groups that are too small. 
""" # The rough plan here is that build_insn_group starts out with the @@ -259,10 +266,9 @@ def build_loop_nest(kernel, sched_index, codegen_state): if not sched_index_info_entries: return [] - si_entry = sched_index_info_entries[0] - sched_index = si_entry.schedule_index - current_iname_set = si_entry.admissible_cond_inames - current_pred_set = (si_entry.required_predicates + origin_si_entry = sched_index_info_entries[0] + current_iname_set = origin_si_entry.admissible_cond_inames + current_pred_set = (origin_si_entry.required_predicates - codegen_state.implemented_predicates) # {{{ grow schedule item group @@ -293,22 +299,19 @@ def build_loop_nest(kernel, sched_index, codegen_state): # {{{ see which inames are actually used in group # And only generate conditionals for those. - from loopy.schedule import find_used_inames_within used_inames = set() for sched_index_info_entry in \ sched_index_info_entries[0:candidate_group_length]: - used_inames |= find_used_inames_within(kernel, - sched_index_info_entry.schedule_index) + used_inames |= sched_index_info_entry.used_inames_within # }}} - only_unshared_inames = remove_inames_for_shared_hw_axes(kernel, + only_unshared_inames = kernel.remove_inames_for_shared_hw_axes( current_iname_set & used_inames) bounds_checks = bounds_check_cache(only_unshared_inames) if (bounds_checks # found a bounds check - or bounds_checks is None # found impossible bounds check or current_pred_set or candidate_group_length == 1): # length-1 must always be an option to reach the recursion base @@ -316,6 +319,11 @@ def build_loop_nest(kernel, sched_index, codegen_state): found_hoists.append((candidate_group_length, bounds_checks, current_pred_set)) + if not bounds_checks and not current_pred_set: + # already no more checks possible, let's not waste time + # checking longer groups. 
@memoize_method
def remove_inames_for_shared_hw_axes(self, cond_inames):
    """
    Return *cond_inames* without those inames that share a hardware-parallel
    tag key with at least one other iname in *cond_inames*.

    A conditional written on one such iname would implicitly restrict
    every other iname mapped to the same tag key as well, so none of them
    may be used for conditionals.

    :arg cond_inames: a set of iname names. Inames whose tag is not a
        :class:`loopy.kernel.data.HardwareParallelTag` are always kept.
    :returns: a :class:`frozenset` of the remaining inames.
    """

    from loopy.kernel.data import HardwareParallelTag

    # Record the tag key of every hardware-parallel iname in cond_inames.
    hw_key_by_iname = {}
    for iname in cond_inames:
        tag = self.iname_to_tag.get(iname)

        if isinstance(tag, HardwareParallelTag):
            hw_key_by_iname[iname] = tag.key

    # Count how many of those inames map to each tag key.
    key_use_counts = {}
    for hw_key in hw_key_by_iname.values():
        key_use_counts[hw_key] = key_use_counts.get(hw_key, 0) + 1

    # An iname is "shared" if its tag key is used by more than one iname.
    shared_inames = set(
            iname for iname, hw_key in hw_key_by_iname.items()
            if key_use_counts[hw_key] > 1)

    return frozenset(cond_inames - shared_inames)