diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index d24b61c12e43cd16431c9727d8fb057319475633..f9648bde7dc4d685ca9daf63ecf15b69496c8651 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -36,15 +36,19 @@ class ReductionOperation(object):
     equality-comparable.
     """
 
-    def result_dtypes(self, target, arg_dtype, inames):
+    def result_dtypes(self, kernel, *arg_dtypes):
         """
-        :arg arg_dtype: may be None if not known
+        :arg arg_dtypes: may be None if not known
         :returns: None if not known, otherwise the returned type
         """
 
         raise NotImplementedError
 
-    def neutral_element(self, dtype, inames):
+    @property
+    def arg_count(self):
+        raise NotImplementedError
+
+    def neutral_element(self, *dtypes):
         raise NotImplementedError
 
     def __hash__(self):
@@ -55,7 +59,7 @@ class ReductionOperation(object):
         # Force subclasses to override
         raise NotImplementedError
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         raise NotImplementedError
 
     def __ne__(self, other):
@@ -87,7 +91,11 @@ class ScalarReductionOperation(ReductionOperation):
         """
         self.forced_result_type = forced_result_type
 
-    def result_dtypes(self, kernel, arg_dtype, inames):
+    @property
+    def arg_count(self):
+        return 1
+
+    def result_dtypes(self, kernel, arg_dtype):
         if self.forced_result_type is not None:
             return (self.parse_result_type(
                     kernel.target, self.forced_result_type),)
@@ -114,18 +122,18 @@ class ScalarReductionOperation(ReductionOperation):
 
 
 class SumReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return 0
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return operand1 + operand2
 
 
 class ProductReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return 1
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return operand1 * operand2
 
 
@@ -166,32 +174,144 @@ def get_ge_neutral(dtype):
 
 
 class MaxReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return get_ge_neutral(dtype)
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return var("max")(operand1, operand2)
 
 
 class MinReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return get_le_neutral(dtype)
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return var("min")(operand1, operand2)
 
 
+# {{{ segmented reduction
+
+class _SegmentedScalarReductionOperation(ReductionOperation):
+    def __init__(self, **kwargs):
+        self.inner_reduction = self.base_reduction_class(**kwargs)
+
+    @property
+    def arg_count(self):
+        return 2
+
+    def prefix(self, scalar_dtype, segment_flag_dtype):
+        return "loopy_segmented_%s_%s_%s" % (self.which,
+                scalar_dtype.numpy_dtype.type.__name__,
+                segment_flag_dtype.numpy_dtype.type.__name__)
+
+    def neutral_element(self, scalar_dtype, segment_flag_dtype):
+        return SegmentedFunction(self, (scalar_dtype, segment_flag_dtype), "init")()
+
+    def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype):
+        return (self.inner_reduction.result_dtypes(kernel, scalar_dtype)
+                + (segment_flag_dtype,))
+
+    def __str__(self):
+        return "segmented(%s)" % self.which
+
+    def __hash__(self):
+        return hash(type(self))
+
+    def __eq__(self, other):
+        return type(self) == type(other)
+
+    def __call__(self, dtypes, operand1, operand2):
+        return SegmentedFunction(self, dtypes, "update")(*(operand1 + operand2))
+
+
+class SegmentedSumReductionOperation(_SegmentedScalarReductionOperation):
+    base_reduction_class = SumReductionOperation
+    which = "sum"
+    op = "((%s) + (%s))"
+
+
+class SegmentedProductReductionOperation(_SegmentedScalarReductionOperation):
+    base_reduction_class = ProductReductionOperation
+    op = "((%s) * (%s))"
+    which = "product"
+
+
+class SegmentedFunction(FunctionIdentifier):
+    init_arg_names = ("reduction_op", "dtypes", "name")
+
+    def __init__(self, reduction_op, dtypes, name):
+        """
+        :arg dtypes: A :class:`tuple` of `(scalar_dtype, segment_flag_dtype)`
+        """
+        self.reduction_op = reduction_op
+        self.dtypes = dtypes
+        self.name = name
+
+    @property
+    def scalar_dtype(self):
+        return self.dtypes[0]
+
+    @property
+    def segment_flag_dtype(self):
+        return self.dtypes[1]
+
+    def __getinitargs__(self):
+        return (self.reduction_op, self.dtypes, self.name)
+
+
+def get_segmented_function_preamble(kernel, func_id):
+    op = func_id.reduction_op
+    prefix = op.prefix(func_id.scalar_dtype, func_id.segment_flag_dtype)
+
+    from pymbolic.mapper.c_code import CCodeMapper
+
+    c_code_mapper = CCodeMapper()
+
+    return (prefix, """
+    inline %(scalar_t)s %(prefix)s_init(%(segment_flag_t)s *segment_flag_out)
+    {
+        *segment_flag_out = 0;
+        return %(neutral)s;
+    }
+
+    inline %(scalar_t)s %(prefix)s_update(
+        %(scalar_t)s op1, %(segment_flag_t)s segment_flag1,
+        %(scalar_t)s op2, %(segment_flag_t)s segment_flag2,
+        %(segment_flag_t)s *segment_flag_out)
+    {
+        *segment_flag_out = segment_flag1 | segment_flag2;
+        return segment_flag2 ? op2 : %(combined)s;
+    }
+    """ % dict(
+            scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype),
+            prefix=prefix,
+            segment_flag_t=kernel.target.dtype_to_typename(
+                    func_id.segment_flag_dtype),
+            neutral=c_code_mapper(
+                    op.inner_reduction.neutral_element(func_id.scalar_dtype)),
+            combined=op.op % ("op1", "op2"),
+            ))
+
+
+# }}}
+
+
 # {{{ argmin/argmax
 
 class _ArgExtremumReductionOperation(ReductionOperation):
-    def prefix(self, dtype):
-        return "loopy_arg%s_%s" % (self.which, dtype.numpy_dtype.type.__name__)
+    def prefix(self, scalar_dtype, index_dtype):
+        return "loopy_arg%s_%s_%s" % (self.which,
+                index_dtype.numpy_dtype.type.__name__,
+                scalar_dtype.numpy_dtype.type.__name__)
+
+    def result_dtypes(self, kernel, scalar_dtype, index_dtype):
+        return (scalar_dtype, index_dtype)
 
-    def result_dtypes(self, kernel, dtype, inames):
-        return (dtype, kernel.index_dtype)
+    def neutral_element(self, scalar_dtype, index_dtype):
+        return ArgExtFunction(self, (scalar_dtype, index_dtype), "init")()
 
-    def neutral_element(self, dtype, inames):
-        return ArgExtFunction(self, dtype, "init", inames)()
+    def __str__(self):
+        return self.which
 
     def __hash__(self):
         return hash(type(self))
@@ -199,11 +319,12 @@ class _ArgExtremumReductionOperation(ReductionOperation):
     def __eq__(self, other):
         return type(self) == type(other)
 
-    def __call__(self, dtype, operand1, operand2, inames):
-        iname, = inames
+    @property
+    def arg_count(self):
+        return 2
 
-        return ArgExtFunction(self, dtype, "update", inames)(
-                *(operand1 + (operand2, var(iname))))
+    def __call__(self, dtypes, operand1, operand2):
+        return ArgExtFunction(self, dtypes, "update")(*(operand1 + operand2))
 
 
 class ArgMaxReductionOperation(_ArgExtremumReductionOperation):
@@ -219,21 +340,28 @@ class ArgMinReductionOperation(_ArgExtremumReductionOperation):
 
 
 class ArgExtFunction(FunctionIdentifier):
-    init_arg_names = ("reduction_op", "scalar_dtype", "name", "inames")
+    init_arg_names = ("reduction_op", "dtypes", "name")
 
-    def __init__(self, reduction_op, scalar_dtype, name, inames):
+    def __init__(self, reduction_op, dtypes, name):
         self.reduction_op = reduction_op
-        self.scalar_dtype = scalar_dtype
+        self.dtypes = dtypes
         self.name = name
-        self.inames = inames
+
+    @property
+    def scalar_dtype(self):
+        return self.dtypes[0]
+
+    @property
+    def index_dtype(self):
+        return self.dtypes[1]
 
     def __getinitargs__(self):
-        return (self.reduction_op, self.scalar_dtype, self.name, self.inames)
+        return (self.reduction_op, self.dtypes, self.name)
 
 
 def get_argext_preamble(kernel, func_id):
     op = func_id.reduction_op
-    prefix = op.prefix(func_id.scalar_dtype)
+    prefix = op.prefix(func_id.scalar_dtype, func_id.index_dtype)
 
     from pymbolic.mapper.c_code import CCodeMapper
 
@@ -267,7 +395,7 @@ def get_argext_preamble(kernel, func_id):
     """ % dict(
             scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype),
             prefix=prefix,
-            index_t=kernel.target.dtype_to_typename(kernel.index_dtype),
+            index_t=kernel.target.dtype_to_typename(func_id.index_dtype),
             neutral=c_code_mapper(neutral(func_id.scalar_dtype)),
             comp=op.update_comparison,
             ))
@@ -284,6 +412,8 @@ _REDUCTION_OPS = {
         "min": MinReductionOperation,
         "argmax": ArgMaxReductionOperation,
         "argmin": ArgMinReductionOperation,
+        "segmented(sum)": SegmentedSumReductionOperation,
+        "segmented(product)": SegmentedProductReductionOperation,
         }
 
 _REDUCTION_OP_PARSERS = [
@@ -325,32 +455,34 @@ def parse_reduction_op(name):
 
 def reduction_function_mangler(kernel, func_id, arg_dtypes):
     if isinstance(func_id, ArgExtFunction) and func_id.name == "init":
-        from loopy.target.opencl import OpenCLTarget
-        if not isinstance(kernel.target, OpenCLTarget):
-            raise LoopyError("only OpenCL supported for now")
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
 
         op = func_id.reduction_op
 
         from loopy.kernel.data import CallMangleInfo
         return CallMangleInfo(
-                target_name="%s_init" % op.prefix(func_id.scalar_dtype),
+                target_name="%s_init" % op.prefix(
+                    func_id.scalar_dtype, func_id.index_dtype),
                 result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.inames),
+                    kernel, func_id.scalar_dtype, func_id.index_dtype),
                 arg_dtypes=(),
                 )
 
     elif isinstance(func_id, ArgExtFunction) and func_id.name == "update":
-        from loopy.target.opencl import OpenCLTarget
-        if not isinstance(kernel.target, OpenCLTarget):
-            raise LoopyError("only OpenCL supported for now")
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
 
         op = func_id.reduction_op
 
         from loopy.kernel.data import CallMangleInfo
         return CallMangleInfo(
-                target_name="%s_update" % op.prefix(func_id.scalar_dtype),
+                target_name="%s_update" % op.prefix(
+                    func_id.scalar_dtype, func_id.index_dtype),
                 result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.inames),
+                    kernel, func_id.scalar_dtype, func_id.index_dtype),
                 arg_dtypes=(
                     func_id.scalar_dtype,
                     kernel.index_dtype,
@@ -358,6 +490,42 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes):
                     kernel.index_dtype),
                 )
 
+    elif isinstance(func_id, SegmentedFunction) and func_id.name == "init":
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
+
+        op = func_id.reduction_op
+
+        from loopy.kernel.data import CallMangleInfo
+        return CallMangleInfo(
+                target_name="%s_init" % op.prefix(
+                    func_id.scalar_dtype, func_id.segment_flag_dtype),
+                result_dtypes=op.result_dtypes(
+                    kernel, func_id.scalar_dtype, func_id.segment_flag_dtype),
+                arg_dtypes=(),
+                )
+
+    elif isinstance(func_id, SegmentedFunction) and func_id.name == "update":
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
+
+        op = func_id.reduction_op
+
+        from loopy.kernel.data import CallMangleInfo
+        return CallMangleInfo(
+                target_name="%s_update" % op.prefix(
+                    func_id.scalar_dtype, func_id.segment_flag_dtype),
+                result_dtypes=op.result_dtypes(
+                    kernel, func_id.scalar_dtype, func_id.segment_flag_dtype),
+                arg_dtypes=(
+                    func_id.scalar_dtype,
+                    func_id.segment_flag_dtype,
+                    func_id.scalar_dtype,
+                    func_id.segment_flag_dtype),
+                )
+
     return None
 
 
@@ -371,4 +539,10 @@ def reduction_preamble_generator(preamble_info):
 
             yield get_argext_preamble(preamble_info.kernel, func.name)
 
+        elif isinstance(func.name, SegmentedFunction):
+            from loopy.target.opencl import CTarget
+            if not isinstance(preamble_info.kernel.target, CTarget):
+                raise LoopyError(
+                        "%s: only C-like targets supported for now" % func.name)
+            yield get_segmented_function_preamble(preamble_info.kernel, func.name)
 # vim: fdm=marker
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 0d8e771954cf26cc11747e745946389420fa5e1b..17226b63addb9e2e30d556730aa326d2ed59128c 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -26,7 +26,7 @@ THE SOFTWARE.
 import six
 from loopy.diagnostic import (
         LoopyError, WriteRaceConditionWarning, warn_with_kernel,
-        LoopyAdvisory, DependencyTypeInferenceFailure)
+        LoopyAdvisory)
 
 import islpy as isl
 
@@ -98,6 +98,7 @@ def check_reduction_iname_uniqueness(kernel):
 
     def map_reduction(expr, rec):
         rec(expr.expr)
+
         for iname in expr.inames:
             iname_to_reduction_count[iname] = (
                     iname_to_reduction_count.get(iname, 0) + 1)
@@ -272,6 +273,191 @@ def find_temporary_scope(kernel):
 
 # {{{ rewrite reduction to imperative form
 
+
+# {{{ reduction utils
+
+def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel):
+    """
+    Multi assignment function calls are currently lowered into OpenCL so that
+    the function call::
+
+       a, b = segmented_sum(x, y, z, w)
+
+    becomes::
+
+       a = segmented_sum_mangled(x, y, z, w, &b).
+
+    For OpenCL, the scope of "b" is significant, and the preamble generation
+    currently assumes the scope is always private. This function forces that to
+    be the case by introducing temporary assignments into the kernel.
+    """
+
+    insn_id_gen = kernel.get_instruction_id_generator()
+    var_name_gen = kernel.get_var_name_generator()
+
+    new_or_updated_instructions = {}
+    new_temporaries = {}
+
+    dep_map = dict(
+            (insn.id, insn.depends_on) for insn in kernel.instructions)
+
+    inverse_dep_map = dict((insn.id, set()) for insn in kernel.instructions)
+
+    import six
+    for insn_id, deps in six.iteritems(dep_map):
+        for dep in deps:
+            inverse_dep_map[dep].add(insn_id)
+
+    del dep_map
+
+    # {{{ utils
+
+    def _add_to_no_sync_with(insn_id, new_no_sync_with_params):
+        insn = kernel.id_to_insn.get(insn_id)
+        insn = new_or_updated_instructions.get(insn_id, insn)
+        new_or_updated_instructions[insn_id] = (
+                insn.copy(
+                    no_sync_with=(
+                        insn.no_sync_with | frozenset(new_no_sync_with_params))))
+
+    def _add_to_depends_on(insn_id, new_depends_on_params):
+        insn = kernel.id_to_insn.get(insn_id)
+        insn = new_or_updated_instructions.get(insn_id, insn)
+        new_or_updated_instructions[insn_id] = (
+                insn.copy(
+                    depends_on=insn.depends_on | frozenset(new_depends_on_params)))
+
+    # }}}
+
+    from loopy.kernel.instruction import CallInstruction
+    for insn in kernel.instructions:
+        if not isinstance(insn, CallInstruction):
+            continue
+
+        if len(insn.assignees) <= 1:
+            continue
+
+        assignees = insn.assignees
+        assignee_var_names = insn.assignee_var_names()
+
+        new_assignees = [assignees[0]]
+        newly_added_assignments_ids = set()
+        needs_replacement = False
+
+        last_added_insn_id = insn.id
+
+        from loopy.kernel.data import temp_var_scope, TemporaryVariable
+
+        FIRST_POINTER_ASSIGNEE_IDX = 1  # noqa
+
+        for assignee_nr, assignee_var_name, assignee in zip(
+                range(FIRST_POINTER_ASSIGNEE_IDX, len(assignees)),
+                assignee_var_names[FIRST_POINTER_ASSIGNEE_IDX:],
+                assignees[FIRST_POINTER_ASSIGNEE_IDX:]):
+
+            if (
+                    assignee_var_name in kernel.temporary_variables
+                    and
+                    (kernel.temporary_variables[assignee_var_name].scope
+                         == temp_var_scope.PRIVATE)):
+                new_assignees.append(assignee)
+                continue
+
+            needs_replacement = True
+
+            # {{{ generate a new assignment instruction
+
+            new_assignee_name = var_name_gen(
+                    "{insn_id}_retval_{assignee_nr}"
+                    .format(insn_id=insn.id, assignee_nr=assignee_nr))
+
+            new_assignment_id = insn_id_gen(
+                    "{insn_id}_assign_retval_{assignee_nr}"
+                    .format(insn_id=insn.id, assignee_nr=assignee_nr))
+
+            newly_added_assignments_ids.add(new_assignment_id)
+
+            import loopy as lp
+            new_temporaries[new_assignee_name] = (
+                    TemporaryVariable(
+                        name=new_assignee_name,
+                        dtype=lp.auto,
+                        scope=temp_var_scope.PRIVATE))
+
+            from pymbolic import var
+            new_assignee = var(new_assignee_name)
+            new_assignees.append(new_assignee)
+
+            new_or_updated_instructions[new_assignment_id] = (
+                    make_assignment(
+                        assignees=(assignee,),
+                        expression=new_assignee,
+                        id=new_assignment_id,
+                        depends_on=frozenset([last_added_insn_id]),
+                        depends_on_is_final=True,
+                        no_sync_with=(
+                            insn.no_sync_with | frozenset([(insn.id, "any")])),
+                        predicates=insn.predicates,
+                        within_inames=insn.within_inames))
+
+            last_added_insn_id = new_assignment_id
+
+            # }}}
+
+        if not needs_replacement:
+            continue
+
+        # {{{ update originating instruction
+
+        orig_insn = new_or_updated_instructions.get(insn.id, insn)
+
+        new_or_updated_instructions[insn.id] = (
+                orig_insn.copy(assignees=tuple(new_assignees)))
+
+        _add_to_no_sync_with(insn.id,
+                [(id, "any") for id in newly_added_assignments_ids])
+
+        # }}}
+
+        # {{{ squash spurious memory dependencies amongst new assignments
+
+        for new_insn_id in newly_added_assignments_ids:
+            _add_to_no_sync_with(new_insn_id,
+                    [(id, "any")
+                     for id in newly_added_assignments_ids
+                     if id != new_insn_id])
+
+        # }}}
+
+        # {{{ update instructions that depend on the originating instruction
+
+        for inverse_dep in inverse_dep_map[insn.id]:
+            _add_to_depends_on(inverse_dep, newly_added_assignments_ids)
+
+            for insn_id, scope in (
+                    new_or_updated_instructions[inverse_dep].no_sync_with):
+                if insn_id == insn.id:
+                    _add_to_no_sync_with(
+                            inverse_dep,
+                            [(id, scope) for id in newly_added_assignments_ids])
+
+        # }}}
+
+    new_temporary_variables = kernel.temporary_variables.copy()
+    new_temporary_variables.update(new_temporaries)
+
+    new_instructions = (
+            list(new_or_updated_instructions.values())
+            + list(insn
+                for insn in kernel.instructions
+                if insn.id not in new_or_updated_instructions))
+
+    return kernel.copy(temporary_variables=new_temporary_variables,
+                       instructions=new_instructions)
+
+# }}}
+
+
 def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
     """Rewrites reductions into their imperative form. With *insn_id_filter*
     specified, operate only on the instruction with an instruction id matching
@@ -295,12 +481,52 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
     var_name_gen = kernel.get_var_name_generator()
     new_temporary_variables = kernel.temporary_variables.copy()
 
-    from loopy.type_inference import TypeInferenceMapper
-    type_inf_mapper = TypeInferenceMapper(kernel)
+    # {{{ helpers
+
+    def _strip_if_scalar(reference, val):
+        if len(reference) == 1:
+            return val[0]
+        else:
+            return val
+
+    def expand_inner_reduction(id, expr, nresults, depends_on, within_inames,
+            within_inames_is_final):
+        from pymbolic.primitives import Call
+        from loopy.symbolic import Reduction
+        assert isinstance(expr, (Call, Reduction))
+
+        temp_var_names = [
+                var_name_gen(id + "_arg" + str(i))
+                for i in range(nresults)]
+
+        for name in temp_var_names:
+            from loopy.kernel.data import TemporaryVariable, temp_var_scope
+            new_temporary_variables[name] = TemporaryVariable(
+                    name=name,
+                    shape=(),
+                    dtype=lp.auto,
+                    scope=temp_var_scope.PRIVATE)
+
+        from pymbolic import var
+        temp_vars = tuple(var(n) for n in temp_var_names)
+
+        call_insn = make_assignment(
+                id=id,
+                assignees=temp_vars,
+                expression=expr,
+                depends_on=depends_on,
+                within_inames=within_inames,
+                within_inames_is_final=within_inames_is_final)
+
+        generated_insns.append(call_insn)
+
+        return temp_vars
+
+    # }}}
 
     # {{{ sequential
 
-    def map_reduction_seq(expr, rec, nresults, arg_dtype,
+    def map_reduction_seq(expr, rec, nresults, arg_dtypes,
             reduction_dtypes):
         outer_insn_inames = temp_kernel.insn_inames(insn)
 
@@ -328,7 +554,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 within_inames=outer_insn_inames - frozenset(expr.inames),
                 within_inames_is_final=insn.within_inames_is_final,
                 depends_on=frozenset(),
-                expression=expr.operation.neutral_element(arg_dtype, expr.inames))
+                expression=expr.operation.neutral_element(*arg_dtypes))
 
         generated_insns.append(init_insn)
 
@@ -339,14 +565,32 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
         if insn.within_inames_is_final:
             update_insn_iname_deps = insn.within_inames | set(expr.inames)
 
+        reduction_insn_depends_on = set([init_id])
+
+        if nresults > 1 and not isinstance(expr.expr, tuple):
+            get_args_insn_id = insn_id_gen(
+                    "%s_%s_get" % (insn.id, "_".join(expr.inames)))
+
+            reduction_expr = expand_inner_reduction(
+                id=get_args_insn_id,
+                expr=expr.expr,
+                nresults=nresults,
+                depends_on=insn.depends_on,
+                within_inames=update_insn_iname_deps,
+                within_inames_is_final=insn.within_inames_is_final)
+
+            reduction_insn_depends_on.add(get_args_insn_id)
+        else:
+            reduction_expr = expr.expr
+
         reduction_insn = make_assignment(
                 id=update_id,
                 assignees=acc_vars,
                 expression=expr.operation(
-                    arg_dtype,
-                    acc_vars if len(acc_vars) > 1 else acc_vars[0],
-                    expr.expr, expr.inames),
-                depends_on=frozenset([init_insn.id]) | insn.depends_on,
+                    arg_dtypes,
+                    _strip_if_scalar(acc_vars, acc_vars),
+                    reduction_expr),
+                depends_on=frozenset(reduction_insn_depends_on) | insn.depends_on,
                 within_inames=update_insn_iname_deps,
                 within_inames_is_final=insn.within_inames_is_final)
 
@@ -382,7 +626,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 v[iname].lt_set(v[0] + size)).get_basic_sets()
         return bs
 
-    def map_reduction_local(expr, rec, nresults, arg_dtype,
+    def map_reduction_local(expr, rec, nresults, arg_dtypes,
             reduction_dtypes):
         red_iname, = expr.inames
 
@@ -441,7 +685,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         base_iname_deps = outer_insn_inames - frozenset(expr.inames)
 
-        neutral = expr.operation.neutral_element(arg_dtype, expr.inames)
+        neutral = expr.operation.neutral_element(*arg_dtypes)
 
         init_id = insn_id_gen("%s_%s_init" % (insn.id, red_iname))
         init_insn = make_assignment(
@@ -455,12 +699,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 depends_on=frozenset())
         generated_insns.append(init_insn)
 
-        def _strip_if_scalar(c):
-            if len(acc_vars) == 1:
-                return c[0]
-            else:
-                return c
-
         init_neutral_id = insn_id_gen("%s_%s_init_neutral" % (insn.id, red_iname))
         init_neutral_insn = make_assignment(
                 id=init_neutral_id,
@@ -471,6 +709,26 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 depends_on=frozenset())
         generated_insns.append(init_neutral_insn)
 
+        transfer_depends_on = set([init_neutral_id, init_id])
+
+        if nresults > 1 and not isinstance(expr.expr, tuple):
+            get_args_insn_id = insn_id_gen(
+                    "%s_%s_get" % (insn.id, red_iname))
+
+            reduction_expr = expand_inner_reduction(
+                    id=get_args_insn_id,
+                    expr=expr.expr,
+                    nresults=nresults,
+                    depends_on=insn.depends_on,
+                    within_inames=(
+                        (outer_insn_inames - frozenset(expr.inames))
+                        | frozenset([red_iname])),
+                    within_inames_is_final=insn.within_inames_is_final)
+
+            transfer_depends_on.add(get_args_insn_id)
+        else:
+            reduction_expr = expr.expr
+
         transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname))
         transfer_insn = make_assignment(
                 id=transfer_id,
@@ -478,15 +736,18 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                     acc_var[outer_local_iname_vars + (var(red_iname),)]
                     for acc_var in acc_vars),
                 expression=expr.operation(
-                    arg_dtype,
-                    _strip_if_scalar(tuple(var(nvn) for nvn in neutral_var_names)),
-                    expr.expr, expr.inames),
+                    arg_dtypes,
+                    _strip_if_scalar(
+                        neutral_var_names,
+                        tuple(var(nvn) for nvn in neutral_var_names)),
+                    reduction_expr),
                 within_inames=(
                     (outer_insn_inames - frozenset(expr.inames))
                     | frozenset([red_iname])),
                 within_inames_is_final=insn.within_inames_is_final,
-                depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on,
-                no_sync_with=frozenset([(init_id, "any")]))
+                depends_on=frozenset(transfer_depends_on) | insn.depends_on,
+                no_sync_with=frozenset(
+                    [(insn_id, "any") for insn_id in transfer_depends_on]))
         generated_insns.append(transfer_insn)
 
         cur_size = 1
@@ -498,7 +759,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         istage = 0
         while cur_size > 1:
-
             new_size = cur_size // 2
             assert new_size * 2 == cur_size
 
@@ -513,17 +773,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                         acc_var[outer_local_iname_vars + (var(stage_exec_iname),)]
                         for acc_var in acc_vars),
                     expression=expr.operation(
-                        arg_dtype,
-                        _strip_if_scalar(tuple(
+                        arg_dtypes,
+                        _strip_if_scalar(acc_vars, tuple(
                             acc_var[
                                 outer_local_iname_vars + (var(stage_exec_iname),)]
                             for acc_var in acc_vars)),
-                        _strip_if_scalar(tuple(
+                        _strip_if_scalar(acc_vars, tuple(
                             acc_var[
                                 outer_local_iname_vars + (
                                     var(stage_exec_iname) + new_size,)]
-                            for acc_var in acc_vars)),
-                        expr.inames),
+                            for acc_var in acc_vars))),
                     within_inames=(
                         base_iname_deps | frozenset([stage_exec_iname])),
                     within_inames_is_final=insn.within_inames_is_final,
@@ -554,24 +813,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
         # Only expand one level of reduction at a time, going from outermost to
         # innermost. Otherwise we get the (iname + insn) dependencies wrong.
 
-        try:
-            arg_dtype = type_inf_mapper(expr.expr)
-        except DependencyTypeInferenceFailure:
-            if unknown_types_ok:
-                arg_dtype = lp.auto
-
-                reduction_dtypes = (lp.auto,)*nresults
-
-            else:
-                raise LoopyError("failed to determine type of accumulator for "
-                        "reduction '%s'" % expr)
-        else:
-            arg_dtype = arg_dtype.with_target(kernel.target)
-
-            reduction_dtypes = expr.operation.result_dtypes(
-                        kernel, arg_dtype, expr.inames)
-            reduction_dtypes = tuple(
-                    dt.with_target(kernel.target) for dt in reduction_dtypes)
+        from loopy.type_inference import (
+                infer_arg_and_reduction_dtypes_for_reduction_expression)
+        arg_dtypes, reduction_dtypes = (
+                infer_arg_and_reduction_dtypes_for_reduction_expression(
+                        temp_kernel, expr, unknown_types_ok))
 
         outer_insn_inames = temp_kernel.insn_inames(insn)
         bad_inames = frozenset(expr.inames) & outer_insn_inames
@@ -621,10 +867,10 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         if n_sequential:
             assert n_local_par == 0
-            return map_reduction_seq(expr, rec, nresults, arg_dtype,
+            return map_reduction_seq(expr, rec, nresults, arg_dtypes,
                     reduction_dtypes)
         elif n_local_par:
-            return map_reduction_local(expr, rec, nresults, arg_dtype,
+            return map_reduction_local(expr, rec, nresults, arg_dtypes,
                     reduction_dtypes)
         else:
             from loopy.diagnostic import warn_with_kernel
@@ -739,6 +985,10 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
     kernel = lp.tag_inames(kernel, new_iname_tags)
 
+    kernel = (
+            _hackily_ensure_multi_assignment_return_values_are_scoped_private(
+                kernel))
+
     return kernel
 
 # }}}
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 50c891be476810887720c4e13c9659966b431f5d..f1a494f30d469511817d204c0476ff79abe00e3b 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -95,7 +95,8 @@ class IdentityMapperMixin(object):
             new_inames.append(new_sym_iname.name)
 
         return Reduction(
-                expr.operation, tuple(new_inames), self.rec(expr.expr, *args),
+                expr.operation, tuple(new_inames),
+                self.rec(expr.expr, *args),
                 allow_simultaneous=expr.allow_simultaneous)
 
     def map_tagged_variable(self, expr, *args):
@@ -192,9 +193,12 @@ class StringifyMapper(StringifyMapperBase):
         return "loc.%d" % expr.index
 
     def map_reduction(self, expr, prec):
+        from pymbolic.mapper.stringifier import PREC_NONE
+
         return "%sreduce(%s, [%s], %s)" % (
                 "simul_" if expr.allow_simultaneous else "",
-                expr.operation, ", ".join(expr.inames), expr.expr)
+                expr.operation, ", ".join(expr.inames),
+                self.rec(expr.expr, PREC_NONE))
 
     def map_tagged_variable(self, expr, prec):
         return "%s$%s" % (expr.name, expr.tag)
@@ -258,8 +262,8 @@ class DependencyMapper(DependencyMapperBase):
                 self.rec(child, *args) for child in expr.parameters)
 
     def map_reduction(self, expr):
-        return (self.rec(expr.expr)
-                - set(p.Variable(iname) for iname in expr.inames))
+        deps = self.rec(expr.expr)
+        return deps - set(p.Variable(iname) for iname in expr.inames)
 
     def map_tagged_variable(self, expr):
         return set([expr])
@@ -428,7 +432,7 @@ class TaggedVariable(p.Variable):
 
 
 class Reduction(p.Expression):
-    """Represents a reduction operation on :attr:`expr`
+    """Represents a reduction operation on :attr:`expr`
     across :attr:`inames`.
 
     .. attribute:: operation
@@ -442,8 +446,11 @@ class Reduction(p.Expression):
 
     .. attribute:: expr
 
-        The expression (as a :class:`pymbolic.primitives.Expression`)
-        on which reduction is performed.
+        An expression which may have tuple type. If the expression has tuple
+        type, it must be one of the following:
+         * a :class:`tuple` of :class:`pymbolic.primitives.Expression`, or
+         * a :class:`loopy.symbolic.Reduction`, or
+         * a function call or substitution rule invocation.
 
     .. attribute:: allow_simultaneous
 
@@ -478,6 +485,22 @@ class Reduction(p.Expression):
         from loopy.library.reduction import ReductionOperation
         assert isinstance(operation, ReductionOperation)
 
+        from loopy.diagnostic import LoopyError
+
+        if operation.arg_count > 1:
+            from pymbolic.primitives import Call
+
+            if not isinstance(expr, (tuple, Reduction, Call)):
+                raise LoopyError("reduction argument must be one of "
+                                 "a tuple, reduction, or call; "
+                                 "got '%s'" % type(expr).__name__)
+        else:
+            # Sanity checks
+            if isinstance(expr, tuple):
+                raise LoopyError("got a tuple argument to a scalar reduction")
+            elif isinstance(expr, Reduction) and expr.is_tuple_typed:
+                raise LoopyError("got a tuple typed argument to a scalar reduction")
+
         self.operation = operation
         self.inames = inames
         self.expr = expr
@@ -487,8 +510,7 @@ class Reduction(p.Expression):
         return (self.operation, self.inames, self.expr, self.allow_simultaneous)
 
     def get_hash(self):
-        return hash((self.__class__, self.operation, self.inames,
-            self.expr))
+        return hash((self.__class__, self.operation, self.inames, self.expr))
 
     def is_equal(self, other):
         return (other.__class__ == self.__class__
@@ -499,6 +521,10 @@ class Reduction(p.Expression):
     def stringifier(self):
         return StringifyMapper
 
+    @property
+    def is_tuple_typed(self):
+        return self.operation.arg_count > 1
+
     @property
     @memoize_method
     def inames_set(self):
@@ -924,7 +950,7 @@ class FunctionToPrimitiveMapper(IdentityMapper):
     turns those into the actual pymbolic primitives used for that.
     """
 
-    def _parse_reduction(self, operation, inames, red_expr,
+    def _parse_reduction(self, operation, inames, red_exprs,
             allow_simultaneous=False):
         if isinstance(inames, p.Variable):
             inames = (inames,)
@@ -941,7 +967,10 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 
             processed_inames.append(iname.name)
 
-        return Reduction(operation, tuple(processed_inames), red_expr,
+        if len(red_exprs) == 1:
+            red_exprs = red_exprs[0]
+
+        return Reduction(operation, tuple(processed_inames), red_exprs,
                 allow_simultaneous=allow_simultaneous)
 
     def map_call(self, expr):
@@ -966,15 +995,14 @@ class FunctionToPrimitiveMapper(IdentityMapper):
                 raise TypeError("cse takes two arguments")
 
         elif name in ["reduce", "simul_reduce"]:
-            if len(expr.parameters) == 3:
-                operation, inames, red_expr = expr.parameters
 
-                if not isinstance(operation, p.Variable):
-                    raise TypeError("operation argument to reduce() "
-                            "must be a symbol")
+            if len(expr.parameters) >= 3:
+                operation, inames = expr.parameters[:2]
+                red_exprs = expr.parameters[2:]
 
-                operation = parse_reduction_op(operation.name)
-                return self._parse_reduction(operation, inames, self.rec(red_expr),
+                operation = parse_reduction_op(str(operation))
+                return self._parse_reduction(operation, inames,
+                        tuple(self.rec(red_expr) for red_expr in red_exprs),
                         allow_simultaneous=(name == "simul_reduce"))
             else:
                 raise TypeError("invalid 'reduce' calling sequence")
@@ -991,12 +1019,17 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 
             operation = parse_reduction_op(name)
             if operation:
-                if len(expr.parameters) != 2:
+                # arg_count counts arguments but not inames
+                if len(expr.parameters) != 1 + operation.arg_count:
                     raise RuntimeError("invalid invocation of "
-                            "reduction operation '%s'" % expr.function.name)
-
-                inames, red_expr = expr.parameters
-                return self._parse_reduction(operation, inames, self.rec(red_expr))
+                            "reduction operation '%s': expected %d arguments, "
+                            "got %d instead" % (expr.function.name,
+                                                1 + operation.arg_count,
+                                                len(expr.parameters)))
+
+                inames = expr.parameters[0]
+                red_exprs = tuple(self.rec(param) for param in expr.parameters[1:])
+                return self._parse_reduction(operation, inames, red_exprs)
 
             else:
                 return IdentityMapper.map_call(self, expr)
diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index 4c1e423e93e104fecd0b49a2b1ef2b4a261e38e7..b8b0cbcbf1236cdf712da998922ac238261a3e6e 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -352,28 +352,39 @@ class TypeInferenceMapper(CombineMapper):
         return [self.kernel.index_dtype]
 
     def map_reduction(self, expr, return_tuple=False):
-        rec_result = self.rec(expr.expr)
-
-        if rec_result:
-            rec_result, = rec_result
-            result = expr.operation.result_dtypes(
-                    self.kernel, rec_result, expr.inames)
+        """
+        :arg return_tuple: If *True*, treat the reduction as having tuple type.
+            Otherwise, if *False*, the reduction must have scalar type.
+        """
+        from loopy.symbolic import Reduction
+        from pymbolic.primitives import Call
+
+        if not return_tuple and expr.is_tuple_typed:
+            raise LoopyError("reductions with more or fewer than one "
+                             "return value may only be used in direct "
+                             "assignments")
+
+        if isinstance(expr.expr, tuple):
+            rec_results = [self.rec(sub_expr) for sub_expr in expr.expr]
+            from itertools import product
+            rec_results = product(*rec_results)
+        elif isinstance(expr.expr, Reduction):
+            rec_results = self.rec(expr.expr, return_tuple=return_tuple)
+        elif isinstance(expr.expr, Call):
+            rec_results = self.map_call(expr.expr, return_tuple=return_tuple)
         else:
-            result = expr.operation.result_dtypes(
-                    self.kernel, None, expr.inames)
-
-        if result is None:
-            return []
+            if return_tuple:
+                raise LoopyError("unknown reduction type for tuple reduction: '%s'"
+                        % type(expr.expr).__name__)
+            else:
+                rec_results = self.rec(expr.expr)
 
         if return_tuple:
-            return [result]
-
+            return [expr.operation.result_dtypes(self.kernel, *rec_result)
+                    for rec_result in rec_results]
         else:
-            if len(result) != 1 and not return_tuple:
-                raise LoopyError("reductions with more or fewer than one "
-                        "return value may only be used in direct assignments")
-
-            return [result[0]]
+            return [expr.operation.result_dtypes(self.kernel, rec_result)[0]
+                    for rec_result in rec_results]
 
 # }}}
 
@@ -617,4 +628,44 @@ def infer_unknown_types(kernel, expect_completion=False):
 
 # }}}
 
+
+# {{{ reduction expression helper
+
+def infer_arg_and_reduction_dtypes_for_reduction_expression(
+        kernel, expr, unknown_types_ok):
+    type_inf_mapper = TypeInferenceMapper(kernel)
+    import loopy as lp
+
+    if expr.is_tuple_typed:
+        arg_dtypes_result = type_inf_mapper(
+                expr, return_tuple=True, return_dtype_set=True)
+
+        if len(arg_dtypes_result) == 1:
+            arg_dtypes = arg_dtypes_result[0]
+        else:
+            if unknown_types_ok:
+                arg_dtypes = [lp.auto] * expr.operation.arg_count
+            else:
+                raise LoopyError("failed to determine types of accumulators for "
+                        "reduction '%s'" % expr)
+    else:
+        try:
+            arg_dtypes = [type_inf_mapper(expr)]
+        except DependencyTypeInferenceFailure:
+            if unknown_types_ok:
+                arg_dtypes = [lp.auto]
+            else:
+                raise LoopyError("failed to determine type of accumulator for "
+                        "reduction '%s'" % expr)
+
+    reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes)
+    reduction_dtypes = tuple(
+            dt.with_target(kernel.target)
+            if dt is not lp.auto else dt
+            for dt in reduction_dtypes)
+
+    return tuple(arg_dtypes), reduction_dtypes
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 1218847a7c42bd420a993d86a7534f066c2ab20e..4bb6a27267bd7b1880265bdd5b47ee676a480fb3 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1987,19 +1987,28 @@ def test_integer_reduction(ctx_factory):
                                    dtype=to_loopy_type(vtype),
                                    shape=lp.auto)
 
-        reductions = [('max', lambda x: x == np.max(var_int)),
-                      ('min', lambda x: x == np.min(var_int)),
-                      ('sum', lambda x: x == np.sum(var_int)),
-                      ('product', lambda x: x == np.prod(var_int)),
-                      ('argmax', lambda x: (x[0] == np.max(var_int) and
-                        var_int[out[1]] == np.max(var_int))),
-                      ('argmin', lambda x: (x[0] == np.min(var_int) and
-                        var_int[out[1]] == np.min(var_int)))]
-
-        for reduction, function in reductions:
+        from collections import namedtuple
+        ReductionTest = namedtuple('ReductionTest', 'kind, check, args')
+
+        reductions = [
+            ReductionTest('max', lambda x: x == np.max(var_int), args='var[k]'),
+            ReductionTest('min', lambda x: x == np.min(var_int), args='var[k]'),
+            ReductionTest('sum', lambda x: x == np.sum(var_int), args='var[k]'),
+            ReductionTest('product', lambda x: x == np.prod(var_int), args='var[k]'),
+            ReductionTest('argmax',
+                lambda x: (
+                    x[0] == np.max(var_int) and var_int[out[1]] == np.max(var_int)),
+                args='var[k], k'),
+            ReductionTest('argmin',
+                lambda x: (
+                    x[0] == np.min(var_int) and var_int[out[1]] == np.min(var_int)),
+                args='var[k], k')
+        ]
+
+        for reduction, function, args in reductions:
             kstr = ("out" if 'arg' not in reduction
                         else "out[0], out[1]")
-            kstr += ' = {0}(k, var[k])'.format(reduction)
+            kstr += ' = {0}(k, {1})'.format(reduction, args)
             knl = lp.make_kernel('{[k]: 0<=k<n}',
                                 kstr,
                                 [var_lp, '...'])
@@ -2152,6 +2161,41 @@ def test_global_barrier_error_if_unordered():
         lp.get_global_barrier_order(knl)
 
 
+def test_multi_argument_reduction_type_inference():
+    from loopy.type_inference import TypeInferenceMapper
+    from loopy.library.reduction import SegmentedSumReductionOperation
+    from loopy.types import to_loopy_type
+    op = SegmentedSumReductionOperation()
+
+    knl = lp.make_kernel("{[i,j]: 0<=i<10 and 0<=j<i}", "")
+
+    int32 = to_loopy_type(np.int32)
+
+    expr = lp.symbolic.Reduction(
+            operation=op,
+            inames=("i",),
+            expr=lp.symbolic.Reduction(
+                operation=op,
+                inames="j",
+                expr=(1, 2),
+                allow_simultaneous=True),
+            allow_simultaneous=True)
+
+    t_inf_mapper = TypeInferenceMapper(knl)
+
+    assert (
+            t_inf_mapper(expr, return_tuple=True, return_dtype_set=True)
+            == [(int32, int32)])
+
+
+def test_multi_argument_reduction_parsing():
+    from loopy.symbolic import parse, Reduction
+
+    assert isinstance(
+            parse("reduce(argmax, i, reduce(argmax, j, i, j))").expr,
+            Reduction)
+
+
 def test_struct_assignment(ctx_factory):
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
diff --git a/test/test_reduction.py b/test/test_reduction.py
index 86e72c0c6644b7b9837a6d74da756c58344b1d6f..be11d7c8cada94596dceb1a8e0e678f8adb582e9 100644
--- a/test/test_reduction.py
+++ b/test/test_reduction.py
@@ -297,7 +297,7 @@ def test_argmax(ctx_factory):
     knl = lp.make_kernel(
             "{[i]: 0<=i<%d}" % n,
             """
-            max_val, max_idx = argmax(i, fabs(a[i]))
+            max_val, max_idx = argmax(i, fabs(a[i]), i)
             """)
 
     knl = lp.add_and_infer_dtypes(knl, {"a": np.float32})
@@ -393,16 +393,24 @@ def test_double_sum_made_unique(ctx_factory):
     assert b.get() == ref
 
 
-def test_parallel_multi_output_reduction():
+def test_parallel_multi_output_reduction(ctx_factory):
     knl = lp.make_kernel(
                 "{[i]: 0<=i<128}",
                 """
-                max_val, max_indices = argmax(i, fabs(a[i]))
+                max_val, max_indices = argmax(i, fabs(a[i]), i)
                 """)
     knl = lp.tag_inames(knl, dict(i="l.0"))
+    knl = lp.add_dtypes(knl, dict(a=np.float64))
     knl = lp.realize_reduction(knl)
-    print(knl)
-    # TODO: Add functional test
+
+    ctx = ctx_factory()
+
+    with cl.CommandQueue(ctx) as queue:
+        a = np.random.rand(128)
+        out, (max_index, max_val) = knl(queue, a=a)
+
+        assert max_val == np.max(a)
+        assert max_index == np.argmax(np.abs(a))
 
 
 if __name__ == "__main__":