diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 895e32daaf2789ff843c5725097e341af28dce50..a7b358698c0b7c1bd17fa9431dc12223ff8ded46 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -839,11 +839,6 @@ class LoopKernel(ImmutableRecordWithoutPickling): unvisited = set(insn.id for insn in self.instructions) - def is_barrier(my_insn_id): - insn = self.id_to_insn[my_insn_id] - from loopy.kernel.instruction import BarrierInstruction - return isinstance(insn, BarrierInstruction) and insn.kind == "global" - while unvisited: stack = [unvisited.pop()] @@ -853,13 +848,16 @@ class LoopKernel(ImmutableRecordWithoutPickling): if top in visiting: visiting.remove(top) + from loopy.kernel.instruction import BarrierInstruction + insn = self.id_to_insn[top] + if isinstance(insn, BarrierInstruction): + if insn.kind == "global": + barriers.append(top) + if top in visited: stack.pop() continue - if is_barrier(top): - barriers.append(top) - visited.add(top) visiting.add(top) @@ -868,47 +866,12 @@ class LoopKernel(ImmutableRecordWithoutPickling): assert child not in visiting stack.append(child) - if len(barriers) == 0: - return () - # Ensure this is the only possible order. - # - # This is done by traversing back up the dependency chain starting with - # the last barrier. If we don't see all the barriers, we know there must - # be a break in the order. - - stack = [barriers[-1]] - visiting.clear() - visited.clear() - - seen_barriers = set() - - while stack: - top = stack[-1] - - if top in visiting: - visiting.remove(top) - - if top in visited: - stack.pop() - continue - - if is_barrier(top): - seen_barriers.add(top) - if len(seen_barriers) == len(barriers): - break - - visited.add(top) - visiting.add(top) - - for child in self.id_to_insn[top].depends_on: - # Check for no cycles. - stack.append(child) - - if len(seen_barriers) < len(barriers): - raise LoopyError( - "Unordered global barrier sets detected: '%s', '%s'" - % (seen_barriers, set(barriers) - seen_barriers)) + for prev_barrier, barrier in zip(barriers, barriers[1:]): + if prev_barrier not in self.recursive_insn_dep_map()[barrier]: + raise LoopyError( + "Unordered global barriers detected: '%s', '%s'" + % (barrier, prev_barrier)) return tuple(barriers)