diff --git a/loopy/schedule.py b/loopy/schedule.py index 55a1c6aabe5baff6db52cee2ad748e693000189f..1d0dc1221a8280d73cdd03858bb72404c7c79afb 100644 --- a/loopy/schedule.py +++ b/loopy/schedule.py @@ -171,7 +171,7 @@ def find_used_inames_within(kernel, sched_index): return result -def loop_nest_map(kernel): +def find_loop_nest_map(kernel): """Returns a dictionary mapping inames to other inames that are always nested around them. """ @@ -212,7 +212,7 @@ def loop_nest_map(kernel): return result -def loop_insn_dep_map(kernel): +def find_loop_insn_dep_map(kernel, loop_nest_map): """Returns a dictionary mapping inames to other instruction ids that need to be scheduled before the iname should be eligible for scheduling. """ @@ -230,10 +230,27 @@ def loop_insn_dep_map(kernel): dep_insn_inames = kernel.insn_inames(dep_insn) if iname in dep_insn_inames: - # Nothing to be learened, dependency is in loop over iname. + # Nothing to be learned, dependency is in loop over iname + # already. continue - result.setdefault(iname, set()).add(dep_insn_id) + # To make sure dep_insn belongs outside of iname, we must prove + # (via loop_nest_map) that all inames that dep_insn will be + # executed inside are nested *around* iname. + if not dep_insn_inames <= loop_nest_map[iname]: + continue + + iname_dep = result.setdefault(iname, set()) + if dep_insn_id not in iname_dep: + logger.debug("{knl}: loop dependency map: iname '{iname}' " + "depends on '{dep_insn}' via '{insn}'" + .format( + knl=kernel.name, + iname=iname, + dep_insn=dep_insn_id, + insn=insn.id)) + + iname_dep.add(dep_insn_id) return result @@ -1226,10 +1243,11 @@ def generate_loop_schedules(kernel, debug_args={}): iname for iname in kernel.all_inames() if isinstance(kernel.iname_to_tag.get(iname), ParallelTag)) + loop_nest_map = find_loop_nest_map(kernel) sched_state = SchedulerState( kernel=kernel, - loop_nest_map=loop_nest_map(kernel), - loop_insn_dep_map=loop_insn_dep_map(kernel), + loop_nest_map=loop_nest_map, + loop_insn_dep_map=find_loop_insn_dep_map(kernel, loop_nest_map), breakable_inames=ilp_inames, ilp_inames=ilp_inames, vec_inames=vec_inames,