diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py
index 85c5019293c6aa79ad853cb938cbe5fe5267a351..0d22dbb88ed99c7c92480d1d39b924cc2198cc3f 100644
--- a/loopy/kernel/instruction.py
+++ b/loopy/kernel/instruction.py
@@ -664,7 +664,7 @@ class MultiAssignmentBase(InstructionBase):
     @memoize_method
     def reduction_inames(self):
         def map_reduction(expr, rec):
-            rec(expr.exprs)
+            rec(expr.expr)
             for iname in expr.inames:
                 result.add(iname)
 
diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index e3c7e6099b4fbce49a3fd3a976740ddcba3a41cb..7037de99464858af858ea7c9d2e0c17c70311e7d 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -65,10 +65,6 @@ class ReductionOperation(object):
     def __ne__(self, other):
         return not self.__eq__(other)
 
-    @property
-    def is_segmented(self):
-        raise NotImplementedError
-
     @staticmethod
     def parse_result_type(target, op_type):
         try:
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 5f62d1a9d3eb40b4a5e9ac29212916b33b04d844..17226b63addb9e2e30d556730aa326d2ed59128c 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -97,7 +97,7 @@ def check_reduction_iname_uniqueness(kernel):
     iname_to_nonsimultaneous_reduction_count = {}
 
     def map_reduction(expr, rec):
-        rec(expr.exprs)
+        rec(expr.expr)
 
         for iname in expr.inames:
             iname_to_reduction_count[iname] = (
@@ -567,13 +567,13 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         reduction_insn_depends_on = set([init_id])
 
-        if not isinstance(expr.exprs, tuple):
+        if nresults > 1 and not isinstance(expr.expr, tuple):
             get_args_insn_id = insn_id_gen(
                     "%s_%s_get" % (insn.id, "_".join(expr.inames)))
 
             reduction_expr = expand_inner_reduction(
                 id=get_args_insn_id,
-                expr=expr.exprs,
+                expr=expr.expr,
                 nresults=nresults,
                 depends_on=insn.depends_on,
                 within_inames=update_insn_iname_deps,
@@ -581,7 +581,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
             reduction_insn_depends_on.add(get_args_insn_id)
         else:
-            reduction_expr = expr.exprs
+            reduction_expr = expr.expr
 
         reduction_insn = make_assignment(
                 id=update_id,
@@ -589,7 +589,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 expression=expr.operation(
                     arg_dtypes,
                     _strip_if_scalar(acc_vars, acc_vars),
-                    _strip_if_scalar(acc_vars, reduction_expr)),
+                    reduction_expr),
                 depends_on=frozenset(reduction_insn_depends_on) | insn.depends_on,
                 within_inames=update_insn_iname_deps,
                 within_inames_is_final=insn.within_inames_is_final)
@@ -626,14 +626,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 v[iname].lt_set(v[0] + size)).get_basic_sets()
         return bs
 
-    def _make_slab_set_from_range(iname, lbound, ubound):
-        v = isl.make_zero_and_vars([iname])
-        bs, = (
-                v[iname].ge_set(v[0] + lbound)
-                &
-                v[iname].lt_set(v[0] + ubound)).get_basic_sets()
-        return bs
-
     def map_reduction_local(expr, rec, nresults, arg_dtypes,
             reduction_dtypes):
         red_iname, = expr.inames
@@ -719,13 +711,13 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         transfer_depends_on = set([init_neutral_id, init_id])
 
-        if not isinstance(expr.exprs, tuple):
+        if nresults > 1 and not isinstance(expr.expr, tuple):
             get_args_insn_id = insn_id_gen(
                     "%s_%s_get" % (insn.id, red_iname))
 
             reduction_expr = expand_inner_reduction(
                     id=get_args_insn_id,
-                    expr=expr.exprs,
+                    expr=expr.expr,
                     nresults=nresults,
                     depends_on=insn.depends_on,
                     within_inames=(
@@ -735,7 +727,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
             transfer_depends_on.add(get_args_insn_id)
         else:
-            reduction_expr = expr.exprs
+            reduction_expr = expr.expr
 
         transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname))
         transfer_insn = make_assignment(
@@ -748,7 +740,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                     _strip_if_scalar(
                         neutral_var_names,
                         tuple(var(nvn) for nvn in neutral_var_names)),
-                    _strip_if_scalar(neutral_var_names, reduction_expr)),
+                    reduction_expr),
                 within_inames=(
                     (outer_insn_inames - frozenset(expr.inames))
                     | frozenset([red_iname])),
@@ -993,8 +985,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
     kernel = lp.tag_inames(kernel, new_iname_tags)
 
-    print(kernel)
-
     kernel = (
             _hackily_ensure_multi_assignment_return_values_are_scoped_private(
                 kernel))
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 89ac05f70a07fd95c88a341801a2d079d9506611..3a462316654fa947026b5f177b96815b5b05ffab 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -96,7 +96,7 @@ class IdentityMapperMixin(object):
 
         return Reduction(
                 expr.operation, tuple(new_inames),
-                self.rec(expr.exprs, *args),
+                self.rec(expr.expr, *args),
                 allow_simultaneous=expr.allow_simultaneous)
 
     def map_tagged_variable(self, expr, *args):
@@ -145,7 +145,7 @@ class WalkMapper(WalkMapperBase):
         if not self.visit(expr):
             return
 
-        self.rec(expr.exprs, *args)
+        self.rec(expr.expr, *args)
 
     map_tagged_variable = WalkMapperBase.map_variable
 
@@ -163,7 +163,7 @@ class CallbackMapper(CallbackMapperBase, IdentityMapper):
 
 class CombineMapper(CombineMapperBase):
     def map_reduction(self, expr):
-        return self.rec(expr.exprs)
+        return self.rec(expr.expr)
 
     map_linear_subscript = CombineMapperBase.map_subscript
 
@@ -195,15 +195,10 @@ class StringifyMapper(StringifyMapperBase):
     def map_reduction(self, expr, prec):
         from pymbolic.mapper.stringifier import PREC_NONE
 
-        if isinstance(expr.exprs, tuple):
-            inner_expr = ", ".join(self.rec(e, PREC_NONE) for e in expr.exprs)
-        else:
-            inner_expr = self.rec(expr.exprs, PREC_NONE)
-
         return "%sreduce(%s, [%s], %s)" % (
                 "simul_" if expr.allow_simultaneous else "",
                 expr.operation, ", ".join(expr.inames),
-                inner_expr)
+                self.rec(expr.expr, PREC_NONE))
 
     def map_tagged_variable(self, expr, prec):
         return "%s$%s" % (expr.name, expr.tag)
@@ -234,7 +229,7 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase):
                 ):
             return []
 
-        return self.rec(expr.exprs, other.exprs, unis)
+        return self.rec(expr.expr, other.expr, unis)
 
     def map_tagged_variable(self, expr, other, urecs):
         new_uni_record = self.unification_record_from_equation(
@@ -267,7 +262,7 @@ class DependencyMapper(DependencyMapperBase):
                 self.rec(child, *args) for child in expr.parameters)
 
     def map_reduction(self, expr):
-        deps = self.rec(expr.exprs)
+        deps = self.rec(expr.expr)
         return deps - set(p.Variable(iname) for iname in expr.inames)
 
     def map_tagged_variable(self, expr):
@@ -449,9 +444,10 @@ class Reduction(p.Expression):
         a list of inames across which reduction on :attr:`expr` is being
         carried out.
 
-    .. attribute:: exprs
+    .. attribute:: expr
 
-        A (tuple-typed) expression which currently may be one of
+        An expression which may have tuple type. If the expression has tuple
+        type, it must be one of the following:
          * a :class:`tuple` of :class:`pymbolic.primitives.Expression`, or
          * a :class:`loopy.symbolic.Reduction`, or
          * a substitution rule invocation.
@@ -462,9 +458,9 @@ class Reduction(p.Expression):
         in precisely one reduction, to avoid mis-nesting errors.
     """
 
-    init_arg_names = ("operation", "inames", "exprs", "allow_simultaneous")
+    init_arg_names = ("operation", "inames", "expr", "allow_simultaneous")
 
-    def __init__(self, operation, inames, exprs, allow_simultaneous=False):
+    def __init__(self, operation, inames, expr, allow_simultaneous=False):
         if isinstance(inames, str):
             inames = tuple(iname.strip() for iname in inames.split(","))
 
@@ -486,43 +482,48 @@ class Reduction(p.Expression):
             from loopy.library.reduction import parse_reduction_op
             operation = parse_reduction_op(operation)
 
-        from pymbolic.primitives import Call
-        if not isinstance(exprs, (tuple, Reduction, Call)):
-            from loopy.diagnostic import LoopyError
-            print(exprs)
-            raise LoopyError(
-                "reduction argument must be a tuple, reduction, or substitution "
-                "invocation, got '%s'" % type(exprs).__name__)
-
         from loopy.library.reduction import ReductionOperation
         assert isinstance(operation, ReductionOperation)
 
+        from loopy.diagnostic import LoopyError
+
+        if operation.arg_count > 1:
+            from pymbolic.primitives import Call
+
+            if not isinstance(expr, (tuple, Reduction, Call)):
+                raise LoopyError("reduction argument must be one of "
+                                 "a tuple, reduction, or substitution rule "
+                                 "invocation, got '%s'" % type(expr).__name__)
+        else:
+            # Sanity checks
+            if isinstance(expr, tuple):
+                raise LoopyError("got a tuple argument to a scalar reduction")
+            elif isinstance(expr, Reduction) and expr.is_tuple_typed:
+                raise LoopyError("got a tuple typed argument to a scalar reduction")
+
         self.operation = operation
         self.inames = inames
-        self.exprs = exprs
+        self.expr = expr
         self.allow_simultaneous = allow_simultaneous
 
     def __getinitargs__(self):
-        return (self.operation, self.inames, self.exprs, self.allow_simultaneous)
+        return (self.operation, self.inames, self.expr, self.allow_simultaneous)
 
     def get_hash(self):
-        return hash((self.__class__, self.operation, self.inames, self.exprs))
+        return hash((self.__class__, self.operation, self.inames, self.expr))
 
     def is_equal(self, other):
         return (other.__class__ == self.__class__
                 and other.operation == self.operation
                 and other.inames == self.inames
-                and other.exprs == self.exprs)
+                and other.expr == self.expr)
 
     def stringifier(self):
         return StringifyMapper
 
     @property
-    def exprs_stripped_if_scalar(self):
-        if isinstance(self.exprs, tuple) and len(self.exprs) == 1:
-            return self.exprs[0]
-        else:
-            return self.exprs
+    def is_tuple_typed(self):
+        return self.operation.arg_count > 1
 
     @property
     @memoize_method
@@ -966,6 +967,11 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 
             processed_inames.append(iname.name)
 
+        if len(red_exprs) == 1:
+            red_exprs = red_exprs[0]
+
+        print("RED EXPRS ARE", red_exprs)
+
         return Reduction(operation, tuple(processed_inames), red_exprs,
                 allow_simultaneous=allow_simultaneous)
 
@@ -991,6 +997,8 @@ class FunctionToPrimitiveMapper(IdentityMapper):
                 raise TypeError("cse takes two arguments")
 
         elif name in ["reduce", "simul_reduce"]:
+
+
             if len(expr.parameters) >= 3:
                 operation, inames = expr.parameters[:2]
                 red_exprs = expr.parameters[2:]
@@ -1413,7 +1421,7 @@ class IndexVariableFinder(CombineMapper):
         return result
 
     def map_reduction(self, expr):
-        result = self.rec(expr.exprs)
+        result = self.rec(expr.expr)
 
         if not (expr.inames_set & result):
             raise RuntimeError("reduction '%s' does not depend on "
diff --git a/loopy/transform/data.py b/loopy/transform/data.py
index ee5ffb6bcf3cda1971261ff29d4d14eafadd00ff..575311b11716f5a52e4713aa51922eb348c839d9 100644
--- a/loopy/transform/data.py
+++ b/loopy/transform/data.py
@@ -683,9 +683,7 @@ def set_temporary_scope(kernel, temp_var_names, scope):
 
 # {{{ reduction_arg_to_subst_rule
 
-def reduction_arg_to_subst_rule(
-        knl, inames, insn_match=None, subst_rule_name=None,
-        strip_if_scalar=False):
+def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=None):
     if isinstance(inames, str):
         inames = [s.strip() for s in inames.split(",")]
 
@@ -697,12 +695,10 @@ def reduction_arg_to_subst_rule(
 
     def map_reduction(expr, rec, nresults=1):
         if frozenset(expr.inames) != inames_set:
-            rec_result = rec(expr.exprs)
-
             return type(expr)(
                     operation=expr.operation,
                     inames=expr.inames,
-                    exprs=rec_result,
+                    expr=rec(expr.expr),
                     allow_simultaneous=expr.allow_simultaneous)
 
         if subst_rule_name is None:
@@ -719,10 +715,7 @@ def reduction_arg_to_subst_rule(
         substs[my_subst_rule_name] = SubstitutionRule(
                 name=my_subst_rule_name,
                 arguments=tuple(inames),
-                expression=(
-                    expr.exprs_stripped_if_scalar
-                    if strip_if_scalar
-                    else expr.exprs))
+                expression=expr.expr)
 
         from pymbolic import var
         iname_vars = [var(iname) for iname in inames]
@@ -730,7 +723,7 @@ def reduction_arg_to_subst_rule(
         return type(expr)(
                 operation=expr.operation,
                 inames=expr.inames,
-                exprs=var(my_subst_rule_name)(*iname_vars),
+                expr=var(my_subst_rule_name)(*iname_vars),
                 allow_simultaneous=expr.allow_simultaneous)
 
     from loopy.symbolic import ReductionCallbackMapper
diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py
index 35c12700806ccf4c0cb56f0ac458d98631cfdc19..ea90abfe27c8de69daf39021b3d0ea5463a2e4c8 100644
--- a/loopy/transform/iname.py
+++ b/loopy/transform/iname.py
@@ -145,7 +145,7 @@ class _InameSplitter(RuleAwareIdentityMapper):
 
             from loopy.symbolic import Reduction
             return Reduction(expr.operation, tuple(new_inames),
-                        self.rec(expr.exprs, expn_state),
+                        self.rec(expr.expr, expn_state),
                         expr.allow_simultaneous)
         else:
             return super(_InameSplitter, self).map_reduction(expr, expn_state)
@@ -1192,13 +1192,13 @@ class _ReductionSplitter(RuleAwareIdentityMapper):
             if self.direction == "in":
                 return Reduction(expr.operation, tuple(leftover_inames),
                         Reduction(expr.operation, tuple(self.inames),
-                            self.rec(expr.exprs, expn_state),
+                            self.rec(expr.expr, expn_state),
                             expr.allow_simultaneous),
                         expr.allow_simultaneous)
             elif self.direction == "out":
                 return Reduction(expr.operation, tuple(self.inames),
                         Reduction(expr.operation, tuple(leftover_inames),
-                            self.rec(expr.exprs, expn_state),
+                            self.rec(expr.expr, expn_state),
                             expr.allow_simultaneous),
                         expr.allow_simultaneous)
             else:
@@ -1592,7 +1592,8 @@ class _ReductionInameUniquifier(RuleAwareIdentityMapper):
             from loopy.symbolic import Reduction
             return Reduction(expr.operation, tuple(new_inames),
                     self.rec(
-                        SubstitutionMapper(make_subst_func(subst_dict))(expr.exprs),
+                        SubstitutionMapper(make_subst_func(subst_dict))(
+                            expr.expr),
                         expn_state),
                     expr.allow_simultaneous)
         else:
diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py
index 7e70f8c77547d39e6402d05fe56ca5dfd8a1fc64..a19e06ecdf7c9966501ebb9600ea4e01614363f4 100644
--- a/loopy/transform/precompute.py
+++ b/loopy/transform/precompute.py
@@ -59,33 +59,9 @@ def storage_axis_exprs(storage_axis_sources, args):
     return result
 
 
-# {{{ identity mapper
-
-class PrecomputeIdentityMapper(RuleAwareIdentityMapper):
-
-    def map_reduction(self, expr, expn_state):
-        from pymbolic.primitives import Call
-        new_exprs = self.rec(expr.exprs, expn_state)
-
-        # If the substitution rule was replaced, precompute turned it into a
-        # scalar, but since reduction only takes tuple types we turn it into a
-        # tuple here.
-        if isinstance(expr.exprs, Call) and not isinstance(new_exprs, Call):
-            new_exprs = (new_exprs,)
-
-        from loopy.symbolic import Reduction
-        return Reduction(
-                expr.operation,
-                expr.inames,
-                new_exprs,
-                expr.allow_simultaneous)
-
-# }}}
-
-
 # {{{ gather rule invocations
 
-class RuleInvocationGatherer(PrecomputeIdentityMapper):
+class RuleInvocationGatherer(RuleAwareIdentityMapper):
     def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within):
         super(RuleInvocationGatherer, self).__init__(rule_mapping_context)
 
@@ -155,7 +131,7 @@ class RuleInvocationGatherer(PrecomputeIdentityMapper):
 
 # {{{ replace rule invocation
 
-class RuleInvocationReplacer(PrecomputeIdentityMapper):
+class RuleInvocationReplacer(RuleAwareIdentityMapper):
     def __init__(self, rule_mapping_context, subst_name, subst_tag, within,
             access_descriptors, array_base_map,
             storage_axis_names, storage_axis_sources,
diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index 34e41740e83ea080728cb8be1671a303c6d2456f..34d3fc5e24f37bb2fb9410023398fc1094069090 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -353,38 +353,38 @@ class TypeInferenceMapper(CombineMapper):
 
     def map_reduction(self, expr, return_tuple=False):
         """
-        :arg return_tuple: If *True*, treat the type of the reduction expression
-            as a tuple type. Otherwise, the number of expressions being reduced over
-            must equal 1, and the type of the first expression is returned.
+        :arg return_tuple: If *True*, treat the reduction as having tuple type.
+        Otherwise, if *False*, the reduction must have scalar type.
         """
         from loopy.symbolic import Reduction
         from pymbolic.primitives import Call
 
-        if isinstance(expr.exprs, tuple):
-            rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs]
+        if not return_tuple and expr.is_tuple_typed:
+            raise LoopyError("reductions with more or fewer than one "
+                             "return value may only be used in direct "
+                             "assignments")
+
+        if isinstance(expr.expr, tuple):
+            rec_results = [self.rec(sub_expr) for sub_expr in expr.expr]
+            from itertools import product
+            rec_results = product(*rec_results)
+        elif isinstance(expr.expr, Reduction):
+            rec_results = self.rec(expr.expr, return_tuple=return_tuple)
+        elif isinstance(expr.expr, Call):
+            rec_results = self.map_call(expr.expr, return_tuple=return_tuple)
+        else:
             if return_tuple:
-                from itertools import product
-                rec_results = product(*rec_results)
+                raise LoopyError("unknown reduction type for tuple reduction: '%s'"
+                        % type(expr.expr).__name__)
             else:
-                rec_results = rec_results[0]
-        elif isinstance(expr.exprs, Reduction):
-            rec_results = self.rec(expr.exprs, return_tuple=return_tuple)
-        elif isinstance(expr.exprs, Call):
-            rec_results = self.map_call(expr.exprs, return_tuple=return_tuple)
+                rec_results = self.rec(expr.expr)
+
+        if return_tuple:
+            return [expr.operation.result_dtypes(self.kernel, *rec_result)
+                    for rec_result in rec_results]
         else:
-            raise LoopyError("unknown reduction type: '%s'"
-                             % type(expr.exprs).__name__)
-
-        if not return_tuple:
-            if any(isinstance(rec_result, tuple) for rec_result in rec_results):
-                raise LoopyError("reductions with more or fewer than one "
-                                 "return value may only be used in direct "
-                                 "assignments")
             return [expr.operation.result_dtypes(self.kernel, rec_result)[0]
-                for rec_result in rec_results]
-
-        return [expr.operation.result_dtypes(self.kernel, *rec_result)
-            for rec_result in rec_results]
+                    for rec_result in rec_results]
 
 # }}}
 
@@ -633,29 +633,29 @@ def infer_unknown_types(kernel, expect_completion=False):
 
 def infer_arg_and_reduction_dtypes_for_reduction_expression(
         kernel, expr, unknown_types_ok):
-    arg_dtypes = []
-
     type_inf_mapper = TypeInferenceMapper(kernel)
     import loopy as lp
 
-    if isinstance(expr.exprs, tuple):
-        exprs = expr.exprs
-    else:
-        exprs = (expr.exprs,)
+    if expr.is_tuple_typed:
+        arg_dtypes_result = type_inf_mapper(expr, return_tuple=True, return_dtype_set=True)
 
-    for sub_expr in exprs:
+        if len(arg_dtypes_result) == 1:
+            arg_dtypes = arg_dtypes_result[0]
+        else:
+            if unknown_types_ok:
+                arg_dtypes = [lp.auto] * expr.operation.arg_count
+            else:
+                raise LoopyError("failed to determine types of accumulators for "
+                        "reduction '%s'" % expr)
+    else:
         try:
-            arg_dtype = type_inf_mapper(sub_expr)
+            arg_dtypes = [type_inf_mapper(expr)]
         except DependencyTypeInferenceFailure:
             if unknown_types_ok:
-                arg_dtype = lp.auto
+                arg_dtypes = [lp.auto]
             else:
                 raise LoopyError("failed to determine type of accumulator for "
-                        "reduction sub-expression '%s'" % sub_expr)
-        else:
-            arg_dtype = arg_dtype.with_target(kernel.target)
-
-        arg_dtypes.append(arg_dtype)
+                        "reduction '%s'" % expr)
 
     reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes)
     reduction_dtypes = tuple(
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 7719607833872127aa1878fbe735d73da1e48bac..b535ec6feca6fa8a48070de7cd7815d83305a995 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1026,7 +1026,7 @@ def test_within_inames_and_reduction():
 
     from pymbolic.primitives import Subscript, Variable
     i2 = lp.Assignment("a",
-            lp.Reduction("sum", "j", (Subscript(Variable("phi"), Variable("j")),)),
+            lp.Reduction("sum", "j", Subscript(Variable("phi"), Variable("j"))),
             within_inames=frozenset(),
             within_inames_is_final=True)
 
@@ -2123,14 +2123,18 @@ def test_multi_argument_reduction_type_inference():
     from loopy.types import to_loopy_type
     op = SegmentedSumReductionOperation()
 
-    knl = lp.make_kernel("{[i]: 0<=i<10}", "")
+    knl = lp.make_kernel("{[i,j]: 0<=i<10 and 0<=j<i}", "")
 
     int32 = to_loopy_type(np.int32)
 
     expr = lp.symbolic.Reduction(
             operation=op,
             inames=("i",),
-            exprs=op.neutral_element(int32, int32),
+            expr=lp.symbolic.Reduction(
+                operation=op,
+                inames="j",
+                expr=(1, 2),
+                allow_simultaneous=True),
             allow_simultaneous=True)
 
     t_inf_mapper = TypeInferenceMapper(knl)
@@ -2140,6 +2144,14 @@ def test_multi_argument_reduction_type_inference():
             == [(int32, int32)])
 
 
+def test_multi_argument_reduction_parsing():
+    from loopy.symbolic import parse, Reduction
+
+    assert isinstance(
+            parse("reduce(argmax, i, reduce(argmax, j, i, j))").expr,
+            Reduction)
+
+
 def test_struct_assignment(ctx_factory):
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)