diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index 1540222b24526b041cc5081bfba5691f454456ae..8a38eebd55b003c624b386bcdf296d2b97e2c97c 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -84,8 +84,8 @@ class ScalarReductionOperation(ReductionOperation): def result_dtypes(self, kernel, arg_dtype, inames): if self.forced_result_type is not None: - return self.parse_result_type( - kernel.target, self.forced_result_type) + return (self.parse_result_type( + kernel.target, self.forced_result_type),) return (arg_dtype,) @@ -289,7 +289,7 @@ def parse_reduction_op(name): def reduction_function_mangler(kernel, func_id, arg_dtypes): - if isinstance(func_id, ArgExtFunction): + if isinstance(func_id, ArgExtFunction) and func_id.name == "init": from loopy.target.opencl import OpenCLTarget if not isinstance(kernel.target, OpenCLTarget): raise LoopyError("only OpenCL supported for now") @@ -298,8 +298,22 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes): from loopy.kernel.data import CallMangleInfo return CallMangleInfo( - target_name="%s_%s" % ( - op.prefix(func_id.scalar_dtype), func_id.name), + target_name="%s_init" % op.prefix(func_id.scalar_dtype), + result_dtypes=op.result_dtypes( + kernel, func_id.scalar_dtype, func_id.inames), + arg_dtypes=(), + ) + + elif isinstance(func_id, ArgExtFunction) and func_id.name == "update": + from loopy.target.opencl import OpenCLTarget + if not isinstance(kernel.target, OpenCLTarget): + raise LoopyError("only OpenCL supported for now") + + op = func_id.reduction_op + + from loopy.kernel.data import CallMangleInfo + return CallMangleInfo( + target_name="%s_update" % op.prefix(func_id.scalar_dtype), result_dtypes=op.result_dtypes( kernel, func_id.scalar_dtype, func_id.inames), arg_dtypes=( diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 51d588ef59b0e99f1ab6504deebd17fe828ff96f..e1ee119d5c9afb985384cffc17c9946c49a5b734 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -469,8 +469,12 @@ def realize_reduction(kernel, insn_id_filter=None): raise LoopyError("failed to determine type of accumulator for " "reduction '%s'" % expr) + arg_dtype = arg_dtype.with_target(kernel.target) + reduction_dtypes = expr.operation.result_dtypes( kernel, arg_dtype, expr.inames) + reduction_dtypes = tuple( + dt.with_target(kernel.target) for dt in reduction_dtypes) ncomp = len(reduction_dtypes)