diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py
index 752e3e4da7908132e1b9ba001451d3b86bd037f9..581f090547370ca1b8cc4752dc70e9408e6ab37c 100644
--- a/loopy/kernel/instruction.py
+++ b/loopy/kernel/instruction.py
@@ -658,11 +658,7 @@ class MultiAssignmentBase(InstructionBase):
     @memoize_method
     def reduction_inames(self):
         def map_reduction(expr, rec):
-            if expr.is_plain_tuple:
-                for sub_expr in expr.exprs:
-                    rec(sub_expr)
-            else:
-                rec(expr.exprs)
+            rec(expr.exprs)
             for iname in expr.inames:
                 result.add(iname)
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 6d6494b5e1fa9c671c40f9c8737f9292527c9360..5ece0db1dffd2cde118bc3104b90ce6faa14a448 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -97,11 +97,7 @@ def check_reduction_iname_uniqueness(kernel):
     iname_to_nonsimultaneous_reduction_count = {}
 
     def map_reduction(expr, rec):
-        if expr.is_plain_tuple:
-            for sub_expr in expr.exprs:
-                rec(sub_expr)
-        else:
-            rec(expr.exprs)
+        rec(expr.exprs)
 
         for iname in expr.inames:
             iname_to_reduction_count[iname] = (
@@ -493,6 +489,39 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
         else:
             return val
 
+    def expand_inner_reduction(id, expr, nresults, depends_on, within_inames,
+            within_inames_is_final):
+        """Assign *expr* to *nresults* fresh private scalar temporaries and
+        return them as a tuple of variables; the new instruction gets *id*."""
+        from pymbolic import var
+        from pymbolic.primitives import Call
+        from loopy.kernel.data import TemporaryVariable, temp_var_scope
+        from loopy.symbolic import Reduction
+
+        assert isinstance(expr, (Call, Reduction))
+
+        result_vars = []
+        for iresult in range(nresults):
+            tv_name = var_name_gen(id + "_arg" + str(iresult))
+            new_temporary_variables[tv_name] = TemporaryVariable(
+                    name=tv_name,
+                    shape=(),
+                    dtype=lp.auto,
+                    scope=temp_var_scope.PRIVATE)
+            result_vars.append(var(tv_name))
+        result_vars = tuple(result_vars)
+
+        generated_insns.append(
+                make_assignment(
+                    id=id,
+                    assignees=result_vars,
+                    expression=expr,
+                    depends_on=depends_on,
+                    within_inames=within_inames,
+                    within_inames_is_final=within_inames_is_final))
+
+        return result_vars
+
     # }}}
 
     # {{{ sequential
@@ -536,14 +565,32 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
         if insn.within_inames_is_final:
             update_insn_iname_deps = insn.within_inames | set(expr.inames)
 
+        reduction_insn_depends_on = set([init_id])
+
+        if not isinstance(expr.exprs, tuple):
+            get_args_insn_id = insn_id_gen(
+                    "%s_%s_get" % (insn.id, "_".join(expr.inames)))
+
+            reduction_expr = expand_inner_reduction(
+                id=get_args_insn_id,
+                expr=expr.exprs,
+                nresults=nresults,
+                depends_on=insn.depends_on,
+                within_inames=update_insn_iname_deps,
+                within_inames_is_final=insn.within_inames_is_final)
+
+            reduction_insn_depends_on.add(get_args_insn_id)
+        else:
+            reduction_expr = expr.exprs
+
         reduction_insn = make_assignment(
                 id=update_id,
                 assignees=acc_vars,
                 expression=expr.operation(
                     arg_dtypes,
                     _strip_if_scalar(acc_vars, acc_vars),
-                    _strip_if_scalar(acc_vars, expr.exprs)),
-                depends_on=frozenset([init_insn.id]) | insn.depends_on,
+                    _strip_if_scalar(acc_vars, reduction_expr)),
+                depends_on=frozenset(reduction_insn_depends_on) | insn.depends_on,
                 within_inames=update_insn_iname_deps,
                 within_inames_is_final=insn.within_inames_is_final)
 
@@ -670,6 +717,26 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 depends_on=frozenset())
         generated_insns.append(init_neutral_insn)
 
+        transfer_depends_on = set([init_neutral_id, init_id])
+
+        if not isinstance(expr.exprs, tuple):
+            get_args_insn_id = insn_id_gen(
+                    "%s_%s_get" % (insn.id, red_iname))
+
+            reduction_expr = expand_inner_reduction(
+                    id=get_args_insn_id,
+                    expr=expr.exprs,
+                    nresults=nresults,
+                    depends_on=insn.depends_on,
+                    within_inames=(
+                        (outer_insn_inames - frozenset(expr.inames))
+                        | frozenset([red_iname])),
+                    within_inames_is_final=insn.within_inames_is_final)
+
+            transfer_depends_on.add(get_args_insn_id)
+        else:
+            reduction_expr = expr.exprs
+
         transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname))
         transfer_insn = make_assignment(
                 id=transfer_id,
@@ -679,15 +746,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 expression=expr.operation(
                     arg_dtypes,
                     _strip_if_scalar(
-                        expr.exprs,
+                        neutral_var_names,
                         tuple(var(nvn) for nvn in neutral_var_names)),
-                    _strip_if_scalar(expr.exprs, expr.exprs)),
+                    _strip_if_scalar(neutral_var_names, reduction_expr)),
                 within_inames=(
                     (outer_insn_inames - frozenset(expr.inames))
                     | frozenset([red_iname])),
                 within_inames_is_final=insn.within_inames_is_final,
-                depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on,
-                no_sync_with=frozenset([(init_id, "any")]))
+                depends_on=frozenset(transfer_depends_on) | insn.depends_on,
+                no_sync_with=frozenset(
+                    [(insn_id, "any") for insn_id in transfer_depends_on]))
         generated_insns.append(transfer_insn)
 
         cur_size = 1
@@ -699,7 +767,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         istage = 0
         while cur_size > 1:
-
             new_size = cur_size // 2
             assert new_size * 2 == cur_size
 
@@ -926,6 +993,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
     kernel = lp.tag_inames(kernel, new_iname_tags)
 
+    import logging
+    logging.getLogger(__name__).debug("%s", kernel)
     kernel = (
             _hackily_ensure_multi_assignment_return_values_are_scoped_private(
                 kernel))
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 8876e295027bdb9e8ee4f0f580f4742d249511ff..89ac05f70a07fd95c88a341801a2d079d9506611 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -96,9 +96,7 @@ class IdentityMapperMixin(object):
 
         return Reduction(
                 expr.operation, tuple(new_inames),
-                (tuple(self.rec(e, *args) for e in expr.exprs)
-                    if expr.is_plain_tuple
-                    else self.rec(expr.exprs, *args)),
+                self.rec(expr.exprs, *args),
                 allow_simultaneous=expr.allow_simultaneous)
 
     def map_tagged_variable(self, expr, *args):
@@ -147,11 +145,7 @@ class WalkMapper(WalkMapperBase):
         if not self.visit(expr):
             return
 
-        if expr.is_plain_tuple:
-            for sub_expr in expr.exprs:
-                self.rec(sub_expr, *args)
-        else:
-            self.rec(expr.exprs, *args)
+        self.rec(expr.exprs, *args)
 
     map_tagged_variable = WalkMapperBase.map_variable
 
@@ -169,10 +163,7 @@ class CallbackMapper(CallbackMapperBase, IdentityMapper):
 
 class CombineMapper(CombineMapperBase):
     def map_reduction(self, expr):
-        if expr.is_plain_tuple:
-            return self.combine(self.rec(sub_expr) for sub_expr in expr.exprs)
-        else:
-            return self.rec(expr.exprs)
+        return self.rec(expr.exprs)
 
     map_linear_subscript = CombineMapperBase.map_subscript
 
@@ -203,12 +194,16 @@ class StringifyMapper(StringifyMapperBase):
 
     def map_reduction(self, expr, prec):
         from pymbolic.mapper.stringifier import PREC_NONE
+
+        if isinstance(expr.exprs, tuple):
+            inner_expr = ", ".join(self.rec(e, PREC_NONE) for e in expr.exprs)
+        else:
+            inner_expr = self.rec(expr.exprs, PREC_NONE)
+
         return "%sreduce(%s, [%s], %s)" % (
                 "simul_" if expr.allow_simultaneous else "",
                 expr.operation, ", ".join(expr.inames),
-                (", ".join(self.rec(e, PREC_NONE) for e in expr.exprs)
-                    if expr.is_plain_tuple
-                    else self.rec(expr.exprs, PREC_NONE)))
+                inner_expr)
 
     def map_tagged_variable(self, expr, prec):
         return "%s$%s" % (expr.name, expr.tag)
@@ -238,15 +233,6 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase):
                 or type(expr.operation) != type(other.operation)  # noqa
                 ):
             return []
-        if expr.is_plain_tuple != other.is_plain_tuple:
-            return []
-
-        if expr.is_plain_tuple:
-            for sub_expr_l, sub_expr_r in zip(expr.exprs, other.exprs):
-                unis = self.rec(sub_expr_l, sub_expr_r, unis)
-                if not unis:
-                    break
-            return unis
 
         return self.rec(expr.exprs, other.exprs, unis)
 
@@ -281,10 +267,7 @@ class DependencyMapper(DependencyMapperBase):
                 self.rec(child, *args) for child in expr.parameters)
 
     def map_reduction(self, expr):
-        if expr.is_plain_tuple:
-            deps = self.combine(self.rec(sub_expr) for sub_expr in expr.exprs)
-        else:
-            deps = self.rec(expr.exprs)
+        deps = self.rec(expr.exprs)
         return deps - set(p.Variable(iname) for iname in expr.inames)
 
     def map_tagged_variable(self, expr):
@@ -503,8 +486,13 @@ class Reduction(p.Expression):
             from loopy.library.reduction import parse_reduction_op
             operation = parse_reduction_op(operation)
 
-        if not isinstance(exprs, tuple):
-            exprs = (exprs,)
+        from pymbolic.primitives import Call
+        # Scalars are not wrapped into a one-tuple here; reject them instead.
+        if not isinstance(exprs, (tuple, Reduction, Call)):
+            from loopy.diagnostic import LoopyError
+            raise LoopyError(
+                "reduction argument must be a tuple, reduction, or substitution "
+                "invocation, got '%s'" % type(exprs).__name__)
 
         from loopy.library.reduction import ReductionOperation
         assert isinstance(operation, ReductionOperation)
@@ -530,12 +518,11 @@ class Reduction(p.Expression):
         return StringifyMapper
 
     @property
-    def is_plain_tuple(self):
-        """
-        :return: True if the reduction expression is a tuple, False if otherwise
-            (the inner expression will still have a tuple type)
-        """
-        return isinstance(self.exprs, tuple)
+    def exprs_stripped_if_scalar(self):
+        # A one-element tuple is unwrapped; anything else is returned as is.
+        exprs = self.exprs
+        is_scalar = isinstance(exprs, tuple) and len(exprs) == 1
+        return exprs[0] if is_scalar else exprs
 
     @property
     @memoize_method
@@ -1426,10 +1413,7 @@ class IndexVariableFinder(CombineMapper):
         return result
 
     def map_reduction(self, expr):
-        if expr.is_plain_tuple:
-            result = self.combine(self.rec(sub_expr) for sub_expr in expr.exprs)
-        else:
-            result = self.rec(expr.exprs)
+        result = self.rec(expr.exprs)
 
         if not (expr.inames_set & result):
             raise RuntimeError("reduction '%s' does not depend on "
diff --git a/loopy/transform/data.py b/loopy/transform/data.py
index a1948b615cc09bd7b4c50774f14c6fd61364150e..ee5ffb6bcf3cda1971261ff29d4d14eafadd00ff 100644
--- a/loopy/transform/data.py
+++ b/loopy/transform/data.py
@@ -684,7 +684,8 @@ def set_temporary_scope(kernel, temp_var_names, scope):
 # {{{ reduction_arg_to_subst_rule
 
 def reduction_arg_to_subst_rule(
-        knl, inames, insn_match=None, subst_rule_name=None, arg_number=0):
+        knl, inames, insn_match=None, subst_rule_name=None,
+        strip_if_scalar=False):
     if isinstance(inames, str):
         inames = [s.strip() for s in inames.split(",")]
 
@@ -696,10 +697,7 @@ def reduction_arg_to_subst_rule(
 
     def map_reduction(expr, rec, nresults=1):
         if frozenset(expr.inames) != inames_set:
-            if expr.is_plain_tuple:
-                rec_result = tuple(rec(sub_expr) for sub_expr in expr.exprs)
-            else:
-                rec_result = rec(expr.exprs)
+            rec_result = rec(expr.exprs)
 
             return type(expr)(
                     operation=expr.operation,
@@ -717,27 +715,22 @@ def reduction_arg_to_subst_rule(
             raise LoopyError("substitution rule '%s' already exists"
                     % my_subst_rule_name)
 
-        if not expr.is_plain_tuple:
-            raise NotImplemented("non-tuple reduction arguments not supported")
-
         from loopy.kernel.data import SubstitutionRule
         substs[my_subst_rule_name] = SubstitutionRule(
                 name=my_subst_rule_name,
                 arguments=tuple(inames),
-                expression=expr.exprs[arg_number])
+                expression=(
+                    expr.exprs_stripped_if_scalar
+                    if strip_if_scalar
+                    else expr.exprs))
 
         from pymbolic import var
         iname_vars = [var(iname) for iname in inames]
 
-        new_exprs = tuple(sub_expr
-                if i != arg_number
-                else var(my_subst_rule_name)(*iname_vars)
-                for i, sub_expr in enumerate(expr.exprs))
-
         return type(expr)(
                 operation=expr.operation,
                 inames=expr.inames,
-                exprs=new_exprs,
+                exprs=var(my_subst_rule_name)(*iname_vars),
                 allow_simultaneous=expr.allow_simultaneous)
 
     from loopy.symbolic import ReductionCallbackMapper
diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py
index 81db51a7e6f3aa6bcef1e804325c7628e32ae095..b9a386b2b69ab1c3136f5f91075bc0129e320748 100644
--- a/loopy/transform/iname.py
+++ b/loopy/transform/iname.py
@@ -145,10 +145,7 @@ class _InameSplitter(RuleAwareIdentityMapper):
 
             from loopy.symbolic import Reduction
             return Reduction(expr.operation, tuple(new_inames),
-                        (tuple(self.rec(sub_expr, expn_state)
-                               for sub_expr in expr.exprs)
-                         if expr.is_plain_tuple
-                         else self.rec(expr.exprs, expn_state)),
+                        self.rec(expr.exprs, expn_state),
                         expr.allow_simultaneous)
         else:
             return super(_InameSplitter, self).map_reduction(expr, expn_state)
@@ -1194,20 +1191,15 @@ class _ReductionSplitter(RuleAwareIdentityMapper):
             if self.direction == "in":
                 return Reduction(expr.operation, tuple(leftover_inames),
                         Reduction(expr.operation, tuple(self.inames),
-                            (tuple(self.rec(sub_expr, expn_state)
-                                  for sub_expr in expr.exprs)
-                             if expr.is_plain_tuple
-                             else self.rec(expr.exprs, expn_state)),
+                            self.rec(expr.exprs, expn_state),
                             expr.allow_simultaneous),
                         expr.allow_simultaneous)
             elif self.direction == "out":
                 return Reduction(expr.operation, tuple(self.inames),
                         Reduction(expr.operation, tuple(leftover_inames),
-                            (tuple(self.rec(sub_expr, expn_state)
-                                  for sub_expr in expr.exprs)
-                             if expr.is_plain_tuple
-                             else self.rec(expr.exprs, expn_state)),
-                            expr.allow_simultaneous))
+                            self.rec(expr.exprs, expn_state),
+                            expr.allow_simultaneous),
+                        expr.allow_simultaneous)
             else:
                 assert False
         else:
@@ -1598,16 +1590,9 @@ class _ReductionInameUniquifier(RuleAwareIdentityMapper):
 
             from loopy.symbolic import Reduction
             return Reduction(expr.operation, tuple(new_inames),
-                    (tuple(self.rec(
-                            SubstitutionMapper(make_subst_func(subst_dict))(
-                                sub_expr),
-                            expn_state)
-                        for sub_expr in expr.exprs)
-                     if expr.is_plain_tuple
-                     else self.rec(
-                             SubstitutionMapper(make_subst_func(subst_dict))(
-                                 expr.exprs),
-                             expn_state)),
+                    self.rec(
+                        SubstitutionMapper(make_subst_func(subst_dict))(expr.exprs),
+                        expn_state),
                     expr.allow_simultaneous)
         else:
             return super(_ReductionInameUniquifier, self).map_reduction(
diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py
index a19e06ecdf7c9966501ebb9600ea4e01614363f4..7e70f8c77547d39e6402d05fe56ca5dfd8a1fc64 100644
--- a/loopy/transform/precompute.py
+++ b/loopy/transform/precompute.py
@@ -59,9 +59,33 @@ def storage_axis_exprs(storage_axis_sources, args):
     return result
 
 
+# {{{ identity mapper
+
+class PrecomputeIdentityMapper(RuleAwareIdentityMapper):
+    """Rule-aware identity mapper with reduction handling for precompute."""
+
+    def map_reduction(self, expr, expn_state):
+        """Rebuild the reduction *expr* around its mapped argument."""
+        from loopy.symbolic import Reduction
+        from pymbolic.primitives import Call
+
+        mapped = self.rec(expr.exprs, expn_state)
+        call_was_replaced = (
+                isinstance(expr.exprs, Call) and not isinstance(mapped, Call))
+
+        # A call that mapping turned into a non-call (the precomputed
+        # substitution rule, now a scalar) is wrapped in a one-tuple
+        # because a reduction argument may not be a bare scalar.
+        argument = (mapped,) if call_was_replaced else mapped
+        return Reduction(
+                expr.operation, expr.inames, argument, expr.allow_simultaneous)
+
+# }}}
+
+
 # {{{ gather rule invocations
 
-class RuleInvocationGatherer(RuleAwareIdentityMapper):
+class RuleInvocationGatherer(PrecomputeIdentityMapper):
     def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within):
         super(RuleInvocationGatherer, self).__init__(rule_mapping_context)
 
@@ -131,7 +155,7 @@ class RuleInvocationGatherer(RuleAwareIdentityMapper):
 
 # {{{ replace rule invocation
 
-class RuleInvocationReplacer(RuleAwareIdentityMapper):
+class RuleInvocationReplacer(PrecomputeIdentityMapper):
     def __init__(self, rule_mapping_context, subst_name, subst_tag, within,
             access_descriptors, array_base_map,
             storage_axis_names, storage_axis_sources,
diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index 3c77c988261b63334f3cb8f0f84e2ea69c87901b..b6aa5d1ad055b316d68ee51e947e648df499d582 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -357,10 +357,18 @@ class TypeInferenceMapper(CombineMapper):
             as a tuple type. Otherwise, the number of expressions being reduced over
             must equal 1, and the type of the first expression is returned.
         """
-        if expr.is_plain_tuple:
+        from loopy.symbolic import Reduction
+        from pymbolic.primitives import Call
+
+        if isinstance(expr.exprs, tuple):
             rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs]
+        elif isinstance(expr.exprs, Reduction):
+            rec_results = [self.rec(expr.exprs, return_tuple=True)]
+        elif isinstance(expr.exprs, Call):
+            rec_results = [self.map_call(expr.exprs, return_tuple=return_tuple)]
         else:
-            rec_results = [self.rec(expr.exprs, return_tuple=return_tuple)]
+            raise LoopyError("unknown reduction type: '%s'"
+                             % type(expr.exprs).__name__)
 
         if any(len(rec_result) == 0 for rec_result in rec_results):
             return []
@@ -629,7 +637,12 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression(
     type_inf_mapper = TypeInferenceMapper(kernel)
     import loopy as lp
 
-    for sub_expr in expr.exprs:
+    if isinstance(expr.exprs, tuple):
+        exprs = expr.exprs
+    else:
+        exprs = (expr.exprs,)
+
+    for sub_expr in exprs:
         try:
             arg_dtype = type_inf_mapper(sub_expr)
         except DependencyTypeInferenceFailure:
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 851a7f0762fcec3ccbb55399e183f5fb51322ac1..d5d1a1f31ba5ad9ecaeedeb92b1188d5208e37c6 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2108,6 +2108,28 @@ def test_barrier_insertion_near_bottom_of_loop():
     assert_barrier_between(knl, "ainit", "aupdate", ignore_barriers_in_levels=[1])
 
 
+def test_multi_argument_reduction_type_inference():
+    """Type inference yields a result for a segmented-sum reduction."""
+    from loopy.type_inference import TypeInferenceMapper
+    from loopy.library.reduction import SegmentedSumReductionOperation
+    from loopy.types import to_loopy_type
+
+    op = SegmentedSumReductionOperation()
+    knl = lp.make_kernel("{[i]: 0<=i<10}", "")
+    int32 = to_loopy_type(np.int32)
+
+    expr = lp.symbolic.Reduction(
+            operation=op,
+            inames=("i",),
+            exprs=op.neutral_element(int32, int32),
+            allow_simultaneous=True)
+
+    result = TypeInferenceMapper(knl)(expr, return_tuple=True)
+    # Smoke check only: an empty result means no type could be inferred.
+    # TODO(review): tighten to the exact expected dtypes once confirmed.
+    assert result, "no type inferred for multi-argument reduction"
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])