diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index a7b358698c0b7c1bd17fa9431dc12223ff8ded46..895e32daaf2789ff843c5725097e341af28dce50 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -839,6 +839,11 @@ class LoopKernel(ImmutableRecordWithoutPickling): unvisited = set(insn.id for insn in self.instructions) + def is_barrier(my_insn_id): + insn = self.id_to_insn[my_insn_id] + from loopy.kernel.instruction import BarrierInstruction + return isinstance(insn, BarrierInstruction) and insn.kind == "global" + while unvisited: stack = [unvisited.pop()] @@ -848,16 +853,13 @@ class LoopKernel(ImmutableRecordWithoutPickling): if top in visiting: visiting.remove(top) - from loopy.kernel.instruction import BarrierInstruction - insn = self.id_to_insn[top] - if isinstance(insn, BarrierInstruction): - if insn.kind == "global": - barriers.append(top) - if top in visited: stack.pop() continue + if is_barrier(top): + barriers.append(top) + visited.add(top) visiting.add(top) @@ -866,12 +868,47 @@ class LoopKernel(ImmutableRecordWithoutPickling): assert child not in visiting stack.append(child) + if len(barriers) == 0: + return () + # Ensure this is the only possible order. - for prev_barrier, barrier in zip(barriers, barriers[1:]): - if prev_barrier not in self.recursive_insn_dep_map()[barrier]: - raise LoopyError( - "Unordered global barriers detected: '%s', '%s'" - % (barrier, prev_barrier)) + # + # This is done by traversing back up the dependency chain starting with + # the last barrier. If we don't see all the barriers, we know there must + # be a break in the order. + + stack = [barriers[-1]] + visiting.clear() + visited.clear() + + seen_barriers = set() + + while stack: + top = stack[-1] + + if top in visiting: + visiting.remove(top) + + if top in visited: + stack.pop() + continue + + if is_barrier(top): + seen_barriers.add(top) + if len(seen_barriers) == len(barriers): + break + + visited.add(top) + visiting.add(top) + + for child in self.id_to_insn[top].depends_on: + # Check for no cycles. + stack.append(child) + + if len(seen_barriers) < len(barriers): + raise LoopyError( + "Unordered global barrier sets detected: '%s', '%s'" + % (seen_barriers, set(barriers) - seen_barriers)) return tuple(barriers)