diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index f60662707962100cdba58a63b0295d215a76ffe5..c13fd295baaaf93c306466d84b946a5275a6103a 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1038,7 +1038,13 @@ class CSEToAssignmentMapper(IdentityMapper): self.add_assignment = add_assignment self.expr_to_var = {} - def map_common_subexpression(self, expr): + def map_reduction(self, expr, additional_inames): + additional_inames = additional_inames | frozenset(expr.inames) + + return super(CSEToAssignmentMapper, self).map_reduction( + expr, additional_inames) + + def map_common_subexpression(self, expr, additional_inames): try: return self.expr_to_var[expr.child] except KeyError: @@ -1048,19 +1054,20 @@ class CSEToAssignmentMapper(IdentityMapper): else: dtype = None - child = self.rec(expr.child) + child = self.rec(expr.child, additional_inames) from pymbolic.primitives import Variable if isinstance(child, Variable): return child - var_name = self.add_assignment(expr.prefix, child, dtype) + var_name = self.add_assignment( + expr.prefix, child, dtype, additional_inames) var = Variable(var_name) self.expr_to_var[expr.child] = var return var def expand_cses(instructions, inames_to_dup, cse_prefix="cse_expr"): - def add_assignment(base_name, expr, dtype): + def add_assignment(base_name, expr, dtype, additional_inames): if base_name is None: base_name = "var" @@ -1085,11 +1092,17 @@ def expand_cses(instructions, inames_to_dup, cse_prefix="cse_expr"): assignee=Variable(new_var_name), expression=expr, predicates=insn.predicates, - within_inames=insn.within_inames, + within_inames=insn.within_inames | additional_inames, within_inames_is_final=insn.within_inames_is_final, ) newly_created_insn_ids.add(new_insn.id) new_insns.append(new_insn) + if insn_inames_to_dup: + raise LoopyError("in-line iname duplication not allowed in " + "an instruction containing a tagged common " + "subexpression (found in instruction '%s')" + % insn) + new_inames_to_dup.append(insn_inames_to_dup) return new_var_name @@ -1107,7 +1120,8 @@ def expand_cses(instructions, inames_to_dup, cse_prefix="cse_expr"): for insn, insn_inames_to_dup in zip(instructions, inames_to_dup): if isinstance(insn, MultiAssignmentBase): - new_insns.append(insn.copy(expression=cseam(insn.expression))) + new_insns.append(insn.copy( + expression=cseam(insn.expression, frozenset()))) new_inames_to_dup.append(insn_inames_to_dup) else: new_insns.append(insn)