diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index a7b358698c0b7c1bd17fa9431dc12223ff8ded46..15084df7e4ad15849a26fa0da4787626e7e116b6 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -839,6 +839,11 @@ class LoopKernel(ImmutableRecordWithoutPickling): unvisited = set(insn.id for insn in self.instructions) + def is_barrier(my_insn_id): + insn = self.id_to_insn[my_insn_id] + from loopy.kernel.instruction import BarrierInstruction + return isinstance(insn, BarrierInstruction) and insn.kind == "global" + while unvisited: stack = [unvisited.pop()] @@ -847,12 +852,8 @@ class LoopKernel(ImmutableRecordWithoutPickling): if top in visiting: visiting.remove(top) - - from loopy.kernel.instruction import BarrierInstruction - insn = self.id_to_insn[top] - if isinstance(insn, BarrierInstruction): - if insn.kind == "global": - barriers.append(top) + if is_barrier(top): + barriers.append(top) if top in visited: stack.pop() @@ -867,11 +868,42 @@ class LoopKernel(ImmutableRecordWithoutPickling): stack.append(child) # Ensure this is the only possible order. + # + # We do this by looking at the barriers in order. + # We check for each adjacent pair (a,b) in the order if a < b, + # i.e. if a is reachable by a chain of dependencies from b. + + visiting.clear() + visited.clear() + for prev_barrier, barrier in zip(barriers, barriers[1:]): - if prev_barrier not in self.recursive_insn_dep_map()[barrier]: - raise LoopyError( - "Unordered global barriers detected: '%s', '%s'" - % (barrier, prev_barrier)) + # Check if prev_barrier is reachable from barrier. + stack = [barrier] + visited.discard(prev_barrier) + + while stack: + top = stack[-1] + + if top in visiting: + visiting.remove(top) + + if top in visited: + stack.pop() + continue + + visited.add(top) + visiting.add(top) + + if top == prev_barrier: + visiting.clear() + break + + for child in self.id_to_insn[top].depends_on: + stack.append(child) + else: + # Search exhausted and we did not find prev_barrier. + raise LoopyError("barriers '%s' and '%s' are not ordered" + % (prev_barrier, barrier)) return tuple(barriers)