diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py
index b4b1d5e48b6e686683b9ab6805700b1ea50769f5..9b62dd0860f77cbec948f6aa4df5e66a91bfb6fd 100644
--- a/pyopencl/reduction.py
+++ b/pyopencl/reduction.py
@@ -216,7 +216,7 @@ def  get_reduction_source(
 
 
 
-def get_reduction_kernel(
+def get_reduction_kernel(stage,
          ctx, out_type, out_type_size,
          neutral, reduce_expr, map_expr=None, arguments=None,
          name="reduce_kernel", preamble="",
@@ -224,7 +224,7 @@ def get_reduction_kernel(
     if map_expr is None:
         map_expr = "in[i]"
 
-    if arguments is None:
+    if stage == 2:
         arguments = "__global const %s *in" % out_type
 
     inf = get_reduction_source(
@@ -261,15 +261,15 @@ class ReductionKernel:
 
         dtype_out = self.dtype_out = np.dtype(dtype_out)
 
-        self.stage_1_inf = get_reduction_kernel(ctx,
+        self.stage_1_inf = get_reduction_kernel(1, ctx,
                 dtype_to_ctype(dtype_out), dtype_out.itemsize,
                 neutral, reduce_expr, map_expr, arguments,
                 name=name+"_stage1", options=options, preamble=preamble)
 
         # stage 2 has only one input and no map expression
-        self.stage_2_inf = get_reduction_kernel(ctx,
+        self.stage_2_inf = get_reduction_kernel(2, ctx,
                 dtype_to_ctype(dtype_out), dtype_out.itemsize,
-                neutral, reduce_expr,
+                neutral, reduce_expr, arguments=arguments,
                 name=name+"_stage2", options=options, preamble=preamble)
 
         from pytools import any