diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index f9648bde7dc4d685ca9daf63ecf15b69496c8651..962b316817f8680fbe3fc7826aed771d1b1d1eec 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -123,7 +123,7 @@ class ScalarReductionOperation(ReductionOperation):
 
 class SumReductionOperation(ScalarReductionOperation):
     def neutral_element(self, dtype):
-        return 0
+        return dtype.numpy_dtype.type(0)
 
     def __call__(self, dtype, operand1, operand2):
         return operand1 + operand2
@@ -131,7 +131,7 @@ class SumReductionOperation(ScalarReductionOperation):
 
 class ProductReductionOperation(ScalarReductionOperation):
     def neutral_element(self, dtype):
-        return 1
+        return dtype.numpy_dtype.type(1)
 
     def __call__(self, dtype, operand1, operand2):
         return operand1 * operand2
@@ -189,8 +189,26 @@ class MinReductionOperation(ScalarReductionOperation):
         return var("min")(operand1, operand2)
 
 
+# {{{ base class for symbolic reduction ops
+
+class ReductionOpFunction(FunctionIdentifier):
+    init_arg_names = ("reduction_op",)
+
+    def __init__(self, reduction_op):
+        self.reduction_op = reduction_op
+
+    def __getinitargs__(self):
+        return (self.reduction_op,)
+
+# }}}
+
+
 # {{{ segmented reduction
 
+class SegmentedOp(ReductionOpFunction):
+    pass
+
+
 class _SegmentedScalarReductionOperation(ReductionOperation):
     def __init__(self, **kwargs):
         self.inner_reduction = self.base_reduction_class(**kwargs)
@@ -205,7 +223,9 @@ class _SegmentedScalarReductionOperation(ReductionOperation):
                 segment_flag_dtype.numpy_dtype.type.__name__)
 
     def neutral_element(self, scalar_dtype, segment_flag_dtype):
-        return SegmentedFunction(self, (scalar_dtype, segment_flag_dtype), "init")()
+        scalar_neutral_element = self.inner_reduction.neutral_element(scalar_dtype)
+        return var("make_tuple")(scalar_neutral_element,
+                segment_flag_dtype.numpy_dtype.type(0))
 
     def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype):
         return (self.inner_reduction.result_dtypes(kernel, scalar_dtype)
@@ -221,7 +241,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation):
         return type(self) == type(other)
 
     def __call__(self, dtypes, operand1, operand2):
-        return SegmentedFunction(self, dtypes, "update")(*(operand1 + operand2))
+        return SegmentedOp(self)(*(operand1 + operand2))
 
 
 class SegmentedSumReductionOperation(_SegmentedScalarReductionOperation):
@@ -236,45 +256,13 @@ class SegmentedProductReductionOperation(_SegmentedScalarReductionOperation):
     which = "product"
 
 
-class SegmentedFunction(FunctionIdentifier):
-    init_arg_names = ("reduction_op", "dtypes", "name")
-
-    def __init__(self, reduction_op, dtypes, name):
-        """
-        :arg dtypes: A :class:`tuple` of `(scalar_dtype, segment_flag_dtype)`
-        """
-        self.reduction_op = reduction_op
-        self.dtypes = dtypes
-        self.name = name
-
-    @property
-    def scalar_dtype(self):
-        return self.dtypes[0]
-
-    @property
-    def segment_flag_dtype(self):
-        return self.dtypes[1]
-
-    def __getinitargs__(self):
-        return (self.reduction_op, self.dtypes, self.name)
-
-
-def get_segmented_function_preamble(kernel, func_id):
+def get_segmented_function_preamble(kernel, func_id, arg_dtypes):
     op = func_id.reduction_op
-    prefix = op.prefix(func_id.scalar_dtype, func_id.segment_flag_dtype)
-
-    from pymbolic.mapper.c_code import CCodeMapper
-
-    c_code_mapper = CCodeMapper()
+    scalar_dtype, segment_flag_dtype = arg_dtypes
+    prefix = op.prefix(scalar_dtype, segment_flag_dtype)
 
     return (prefix, """
-    inline %(scalar_t)s %(prefix)s_init(%(segment_flag_t)s *segment_flag_out)
-    {
-        *segment_flag_out = 0;
-        return %(neutral)s;
-    }
-
-    inline %(scalar_t)s %(prefix)s_update(
+    inline %(scalar_t)s %(prefix)s_op(
         %(scalar_t)s op1, %(segment_flag_t)s segment_flag1,
         %(scalar_t)s op2, %(segment_flag_t)s segment_flag2,
         %(segment_flag_t)s *segment_flag_out)
@@ -283,32 +271,36 @@ def get_segmented_function_preamble(kernel, func_id):
         return segment_flag2 ? op2 : %(combined)s;
     }
     """ % dict(
-            scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype),
+            scalar_t=kernel.target.dtype_to_typename(scalar_dtype),
             prefix=prefix,
-            segment_flag_t=kernel.target.dtype_to_typename(
-                    func_id.segment_flag_dtype),
-            neutral=c_code_mapper(
-                    op.inner_reduction.neutral_element(func_id.scalar_dtype)),
+            segment_flag_t=kernel.target.dtype_to_typename(segment_flag_dtype),
             combined=op.op % ("op1", "op2"),
             ))
 
-
 # }}}
 
 
 # {{{ argmin/argmax
 
+class ArgExtOp(ReductionOpFunction):
+    pass
+
+
 class _ArgExtremumReductionOperation(ReductionOperation):
     def prefix(self, scalar_dtype, index_dtype):
         return "loopy_arg%s_%s_%s" % (self.which,
-                index_dtype.numpy_dtype.type.__name__,
-                scalar_dtype.numpy_dtype.type.__name__)
+                scalar_dtype.numpy_dtype.type.__name__,
+                index_dtype.numpy_dtype.type.__name__)
 
     def result_dtypes(self, kernel, scalar_dtype, index_dtype):
         return (scalar_dtype, index_dtype)
 
     def neutral_element(self, scalar_dtype, index_dtype):
-        return ArgExtFunction(self, (scalar_dtype, index_dtype), "init")()
+        scalar_neutral_func = (
+                get_ge_neutral if self.neutral_sign < 0 else get_le_neutral)
+        scalar_neutral_element = scalar_neutral_func(scalar_dtype)
+        return var("make_tuple")(scalar_neutral_element,
+                index_dtype.numpy_dtype.type(-1))
 
     def __str__(self):
         return self.which
@@ -324,7 +316,7 @@ class _ArgExtremumReductionOperation(ReductionOperation):
         return 2
 
     def __call__(self, dtypes, operand1, operand2):
-        return ArgExtFunction(self, dtypes, "update")(*(operand1 + operand2))
+        return ArgExtOp(self)(*(operand1 + operand2))
 
 
 class ArgMaxReductionOperation(_ArgExtremumReductionOperation):
@@ -339,44 +331,15 @@ class ArgMinReductionOperation(_ArgExtremumReductionOperation):
     neutral_sign = +1
 
 
-class ArgExtFunction(FunctionIdentifier):
-    init_arg_names = ("reduction_op", "dtypes", "name")
-
-    def __init__(self, reduction_op, dtypes, name):
-        self.reduction_op = reduction_op
-        self.dtypes = dtypes
-        self.name = name
-
-    @property
-    def scalar_dtype(self):
-        return self.dtypes[0]
-
-    @property
-    def index_dtype(self):
-        return self.dtypes[1]
-
-    def __getinitargs__(self):
-        return (self.reduction_op, self.dtypes, self.name)
-
-
-def get_argext_preamble(kernel, func_id):
+def get_argext_preamble(kernel, func_id, arg_dtypes):
     op = func_id.reduction_op
-    prefix = op.prefix(func_id.scalar_dtype, func_id.index_dtype)
-
-    from pymbolic.mapper.c_code import CCodeMapper
+    scalar_dtype = arg_dtypes[0]
+    index_dtype = arg_dtypes[1]
 
-    c_code_mapper = CCodeMapper()
-
-    neutral = get_ge_neutral if op.neutral_sign < 0 else get_le_neutral
+    prefix = op.prefix(scalar_dtype, index_dtype)
 
     return (prefix, """
-    inline %(scalar_t)s %(prefix)s_init(%(index_t)s *index_out)
-    {
-        *index_out = INT_MIN;
-        return %(neutral)s;
-    }
-
-    inline %(scalar_t)s %(prefix)s_update(
+    inline %(scalar_t)s %(prefix)s_op(
         %(scalar_t)s op1, %(index_t)s index1,
         %(scalar_t)s op2, %(index_t)s index2,
         %(index_t)s *index_out)
@@ -393,10 +356,9 @@ def get_argext_preamble(kernel, func_id):
         }
     }
     """ % dict(
-            scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype),
+            scalar_t=kernel.target.dtype_to_typename(scalar_dtype),
             prefix=prefix,
-            index_t=kernel.target.dtype_to_typename(func_id.index_dtype),
-            neutral=c_code_mapper(neutral(func_id.scalar_dtype)),
+            index_t=kernel.target.dtype_to_typename(index_dtype),
             comp=op.update_comparison,
             ))
 
@@ -454,76 +416,47 @@ def parse_reduction_op(name):
 
 
 def reduction_function_mangler(kernel, func_id, arg_dtypes):
-    if isinstance(func_id, ArgExtFunction) and func_id.name == "init":
+    if isinstance(func_id, ArgExtOp):
         from loopy.target.opencl import CTarget
         if not isinstance(kernel.target, CTarget):
             raise LoopyError("%s: only C-like targets supported for now" % func_id)
 
         op = func_id.reduction_op
+        scalar_dtype = arg_dtypes[0]
+        index_dtype = arg_dtypes[1]
 
         from loopy.kernel.data import CallMangleInfo
         return CallMangleInfo(
-                target_name="%s_init" % op.prefix(
-                    func_id.scalar_dtype, func_id.index_dtype),
+                target_name="%s_op" % op.prefix(
+                    scalar_dtype, index_dtype),
                 result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.index_dtype),
-                arg_dtypes=(),
-                )
-
-    elif isinstance(func_id, ArgExtFunction) and func_id.name == "update":
-        from loopy.target.opencl import CTarget
-        if not isinstance(kernel.target, CTarget):
-            raise LoopyError("%s: only C-like targets supported for now" % func_id)
-
-        op = func_id.reduction_op
-
-        from loopy.kernel.data import CallMangleInfo
-        return CallMangleInfo(
-                target_name="%s_update" % op.prefix(
-                    func_id.scalar_dtype, func_id.index_dtype),
-                result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.index_dtype),
+                    kernel, scalar_dtype, index_dtype),
                 arg_dtypes=(
-                    func_id.scalar_dtype,
-                    kernel.index_dtype,
-                    func_id.scalar_dtype,
-                    kernel.index_dtype),
-                )
-
-    elif isinstance(func_id, SegmentedFunction) and func_id.name == "init":
-        from loopy.target.opencl import CTarget
-        if not isinstance(kernel.target, CTarget):
-            raise LoopyError("%s: only C-like targets supported for now" % func_id)
-
-        op = func_id.reduction_op
-
-        from loopy.kernel.data import CallMangleInfo
-        return CallMangleInfo(
-                target_name="%s_init" % op.prefix(
-                    func_id.scalar_dtype, func_id.segment_flag_dtype),
-                result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.segment_flag_dtype),
-                arg_dtypes=(),
+                    scalar_dtype,
+                    index_dtype,
+                    scalar_dtype,
+                    index_dtype),
                 )
 
-    elif isinstance(func_id, SegmentedFunction) and func_id.name == "update":
+    elif isinstance(func_id, SegmentedOp):
         from loopy.target.opencl import CTarget
         if not isinstance(kernel.target, CTarget):
             raise LoopyError("%s: only C-like targets supported for now" % func_id)
 
         op = func_id.reduction_op
+        scalar_dtype, segment_flag_dtype = arg_dtypes
 
         from loopy.kernel.data import CallMangleInfo
         return CallMangleInfo(
-                target_name="%s_update" % op.prefix(
-                    func_id.scalar_dtype, func_id.segment_flag_dtype),
+                target_name="%s_op" % op.prefix(
+                    scalar_dtype, segment_flag_dtype),
                 result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.segment_flag_dtype),
+                    kernel, scalar_dtype, segment_flag_dtype),
                 arg_dtypes=(
-                    func_id.scalar_dtype,
-                    func_id.segment_flag_dtype,
-                    func_id.scalar_dtype,
-                    func_id.segment_flag_dtype),
+                    scalar_dtype,
+                    segment_flag_dtype,
+                    scalar_dtype,
+                    segment_flag_dtype),
                 )
 
     return None
@@ -533,16 +466,18 @@ def reduction_preamble_generator(preamble_info):
     from loopy.target.opencl import OpenCLTarget
 
     for func in preamble_info.seen_functions:
-        if isinstance(func.name, ArgExtFunction):
+        if isinstance(func.name, ArgExtOp):
             if not isinstance(preamble_info.kernel.target, OpenCLTarget):
                 raise LoopyError("only OpenCL supported for now")
 
-            yield get_argext_preamble(preamble_info.kernel, func.name)
+            yield get_argext_preamble(preamble_info.kernel, func.name,
+                    func.arg_dtypes)
 
-        elif isinstance(func.name, SegmentedFunction):
+        elif isinstance(func.name, SegmentedOp):
             if not isinstance(preamble_info.kernel.target, OpenCLTarget):
                 raise LoopyError("only OpenCL supported for now")
 
-            yield get_segmented_function_preamble(preamble_info.kernel, func.name)
+            yield get_segmented_function_preamble(preamble_info.kernel, func.name,
+                    func.arg_dtypes)
 
 # vim: fdm=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 21db62610f3a3160bcc3069c3e480e85cc4712f8..ad5fd72b65b9946156d1067aabdb4ff510d6ec63 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2335,6 +2335,47 @@ def test_kernel_var_name_generator():
     assert vng("b") != "b"
 
 
+def test_complex_argmin(ctx_factory):
+    cl_ctx = ctx_factory()
+    knl = lp.make_kernel(
+            "{[ictr,itgt,idim]: "
+            "0<=itgt<ntargets "
+            "and 0<=ictr<ncenters "
+            "and 0<=idim<ambient_dim}",
+
+            """
+            for itgt
+                for ictr
+                    <> dist_sq = sum(idim,
+                            (tgt[idim,itgt] - center[idim,ictr])**2)
+                    <> in_disk = dist_sq < (radius[ictr]*1.05)**2
+                    <> matches = (
+                            (in_disk
+                                and qbx_forced_limit == 0)
+                            or (in_disk
+                                    and qbx_forced_limit != 0
+                                    and qbx_forced_limit * center_side[ictr] > 0)
+                            )
+
+                    <> post_dist_sq = if(matches, dist_sq, HUGE)
+                end
+                <> min_dist_sq, <> min_ictr = argmin(ictr, ictr, post_dist_sq)
+
+                tgt_to_qbx_center[itgt] = if(min_dist_sq < HUGE, min_ictr, -1)
+            end
+            """)
+
+    knl = lp.fix_parameters(knl, ambient_dim=2)
+    knl = lp.add_and_infer_dtypes(knl, {
+            "tgt,center,radius,HUGE": np.float32, 
+            "center_side,qbx_forced_limit": np.int32,
+            })
+
+    lp.auto_test_vs_ref(knl, cl_ctx, knl, parameters={
+            "HUGE": 1e20, "ncenters": 200, "ntargets": 300,
+            "qbx_forced_limit": 1})
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])