# {{{ replace_instruction_ids

def replace_instruction_ids_in_insn(insn, replacements):
    """Return *insn* with the instruction IDs it refers to substituted.

    :arg insn: an object exposing ``depends_on`` (an iterable of
        instruction IDs), ``no_sync_with`` (an iterable of
        ``(insn_id, scope)`` pairs) and a ``copy`` method that accepts
        ``depends_on`` and ``no_sync_with`` as keyword arguments.
    :arg replacements: a mapping from an instruction ID to an iterable of
        the IDs that take its place. An empty iterable drops the ID
        altogether.
    :returns: *insn* itself (the same object) if none of the IDs it refers
        to is a key of *replacements*, otherwise a copy whose
        ``depends_on`` and ``no_sync_with`` are new :class:`frozenset`
        instances with the substitutions applied.
    """
    changed = False

    new_depends_on = []
    for dep in insn.depends_on:
        if dep in replacements:
            # Extend instead of indexing the first replacement so that an
            # empty replacement removes the dependency rather than raising
            # IndexError; this matches the no_sync_with handling below.
            new_depends_on.extend(replacements[dep])
            changed = True
        else:
            new_depends_on.append(dep)

    new_no_sync_with = []
    for insn_id, scope in insn.no_sync_with:
        if insn_id in replacements:
            # Every replacement inherits the scope of the ID it replaces.
            new_no_sync_with.extend(
                    (repl, scope) for repl in replacements[insn_id])
            changed = True
        else:
            new_no_sync_with.append((insn_id, scope))

    if not changed:
        # Hand back the original object so that callers can detect
        # "nothing happened" by identity and no needless copy is made.
        return insn

    return insn.copy(
            depends_on=frozenset(new_depends_on),
            no_sync_with=frozenset(new_no_sync_with))


def replace_instruction_ids(kernel, replacements):
    """Apply :func:`replace_instruction_ids_in_insn` to every instruction
    of *kernel*.

    :arg kernel: an object exposing ``instructions`` (an iterable of
        instructions) and a ``copy`` method that accepts ``instructions``
        as a keyword argument.
    :arg replacements: a mapping from an instruction ID to an iterable of
        the IDs that take its place.
    :returns: *kernel* itself if *replacements* is empty, otherwise a copy
        of *kernel* carrying the list of rewritten instructions.
    """
    if not replacements:
        return kernel

    return kernel.copy(instructions=[
        replace_instruction_ids_in_insn(insn, replacements)
        for insn in kernel.instructions])

# }}}