diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 5050f487c5bd45d0b865f146aa759953b023181a..2b6d97c38a12b47e5b4653297c18b24c40ed938b 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -455,6 +455,12 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): depends_on=frozenset()) generated_insns.append(init_insn) + def _strip_if_scalar(c): + if len(acc_vars) == 1: + return c[0] + else: + return c + init_neutral_id = insn_id_gen("%s_%s_init_neutral" % (insn.id, red_iname)) init_neutral_insn = make_assignment( id=init_neutral_id, @@ -473,7 +479,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): for acc_var in acc_vars), expression=expr.operation( arg_dtype, - tuple(var(nvn) for nvn in neutral_var_names), + _strip_if_scalar(tuple(var(nvn) for nvn in neutral_var_names)), expr.expr, expr.inames), within_inames=( (outer_insn_inames - frozenset(expr.inames)) @@ -483,12 +489,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): no_sync_with=frozenset([(init_id, "any")])) generated_insns.append(transfer_insn) - def _strip_if_scalar(c): - if len(acc_vars) == 1: - return c[0] - else: - return c - cur_size = 1 while cur_size < size: cur_size *= 2