diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py index b4b1d5e48b6e686683b9ab6805700b1ea50769f5..9b62dd0860f77cbec948f6aa4df5e66a91bfb6fd 100644 --- a/pyopencl/reduction.py +++ b/pyopencl/reduction.py @@ -216,7 +216,7 @@ def get_reduction_source( -def get_reduction_kernel( +def get_reduction_kernel(stage, ctx, out_type, out_type_size, neutral, reduce_expr, map_expr=None, arguments=None, name="reduce_kernel", preamble="", @@ -224,7 +224,7 @@ def get_reduction_kernel( if map_expr is None: map_expr = "in[i]" - if arguments is None: + if stage == 2: arguments = "__global const %s *in" % out_type inf = get_reduction_source( @@ -261,15 +261,15 @@ class ReductionKernel: dtype_out = self.dtype_out = np.dtype(dtype_out) - self.stage_1_inf = get_reduction_kernel(ctx, + self.stage_1_inf = get_reduction_kernel(1, ctx, dtype_to_ctype(dtype_out), dtype_out.itemsize, neutral, reduce_expr, map_expr, arguments, name=name+"_stage1", options=options, preamble=preamble) # stage 2 has only one input and no map expression - self.stage_2_inf = get_reduction_kernel(ctx, + self.stage_2_inf = get_reduction_kernel(2, ctx, dtype_to_ctype(dtype_out), dtype_out.itemsize, - neutral, reduce_expr, + neutral, reduce_expr, arguments=arguments, name=name+"_stage2", options=options, preamble=preamble) from pytools import any