From b02329060fde472312b2ac0c38e78409cb7d1822 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Fri, 7 Apr 2017 01:44:55 -0500 Subject: [PATCH] global_barrier_order: Try hard to avoid using recursive_insn_dep_map(). --- loopy/kernel/__init__.py | 59 ++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index a7b35869..895e32da 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -839,6 +839,11 @@ class LoopKernel(ImmutableRecordWithoutPickling): unvisited = set(insn.id for insn in self.instructions) + def is_barrier(my_insn_id): + insn = self.id_to_insn[my_insn_id] + from loopy.kernel.instruction import BarrierInstruction + return isinstance(insn, BarrierInstruction) and insn.kind == "global" + while unvisited: stack = [unvisited.pop()] @@ -848,16 +853,13 @@ class LoopKernel(ImmutableRecordWithoutPickling): if top in visiting: visiting.remove(top) - from loopy.kernel.instruction import BarrierInstruction - insn = self.id_to_insn[top] - if isinstance(insn, BarrierInstruction): - if insn.kind == "global": - barriers.append(top) - if top in visited: stack.pop() continue + if is_barrier(top): + barriers.append(top) + visited.add(top) visiting.add(top) @@ -866,12 +868,47 @@ class LoopKernel(ImmutableRecordWithoutPickling): assert child not in visiting stack.append(child) + if len(barriers) == 0: + return () + # Ensure this is the only possible order. - for prev_barrier, barrier in zip(barriers, barriers[1:]): - if prev_barrier not in self.recursive_insn_dep_map()[barrier]: - raise LoopyError( - "Unordered global barriers detected: '%s', '%s'" - % (barrier, prev_barrier)) + # + # This is done by traversing back up the dependency chain starting with + # the last barrier. If we don't see all the barriers, we know there must + # be a break in the order. + + stack = [barriers[-1]] + visiting.clear() + visited.clear() + + seen_barriers = set() + + while stack: + top = stack[-1] + + if top in visiting: + visiting.remove(top) + + if top in visited: + stack.pop() + continue + + if is_barrier(top): + seen_barriers.add(top) + if len(seen_barriers) == len(barriers): + break + + visited.add(top) + visiting.add(top) + + for child in self.id_to_insn[top].depends_on: + # Check for no cycles. + stack.append(child) + + if len(seen_barriers) < len(barriers): + raise LoopyError( + "Unordered global barrier sets detected: '%s', '%s'" + % (seen_barriers, set(barriers) - seen_barriers)) return tuple(barriers) -- GitLab