From 2c77711ad436ca29000f0c9791948787bb6f4b34 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 11 May 2016 21:56:28 -0500 Subject: [PATCH] Fixes to multivalued functions --- loopy/library/reduction.py | 24 +++++++++++++++++++----- loopy/preprocess.py | 4 ++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index 1540222b2..8a38eebd5 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -84,8 +84,8 @@ class ScalarReductionOperation(ReductionOperation): def result_dtypes(self, kernel, arg_dtype, inames): if self.forced_result_type is not None: - return self.parse_result_type( - kernel.target, self.forced_result_type) + return (self.parse_result_type( + kernel.target, self.forced_result_type),) return (arg_dtype,) @@ -289,7 +289,7 @@ def parse_reduction_op(name): def reduction_function_mangler(kernel, func_id, arg_dtypes): - if isinstance(func_id, ArgExtFunction): + if isinstance(func_id, ArgExtFunction) and func_id.name == "init": from loopy.target.opencl import OpenCLTarget if not isinstance(kernel.target, OpenCLTarget): raise LoopyError("only OpenCL supported for now") @@ -298,8 +298,22 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes): from loopy.kernel.data import CallMangleInfo return CallMangleInfo( - target_name="%s_%s" % ( - op.prefix(func_id.scalar_dtype), func_id.name), + target_name="%s_init" % op.prefix(func_id.scalar_dtype), + result_dtypes=op.result_dtypes( + kernel, func_id.scalar_dtype, func_id.inames), + arg_dtypes=(), + ) + + elif isinstance(func_id, ArgExtFunction) and func_id.name == "update": + from loopy.target.opencl import OpenCLTarget + if not isinstance(kernel.target, OpenCLTarget): + raise LoopyError("only OpenCL supported for now") + + op = func_id.reduction_op + + from loopy.kernel.data import CallMangleInfo + return CallMangleInfo( + target_name="%s_update" % op.prefix(func_id.scalar_dtype), result_dtypes=op.result_dtypes( kernel, func_id.scalar_dtype, func_id.inames), arg_dtypes=( diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 51d588ef5..e1ee119d5 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -469,8 +469,12 @@ def realize_reduction(kernel, insn_id_filter=None): raise LoopyError("failed to determine type of accumulator for " "reduction '%s'" % expr) + arg_dtype = arg_dtype.with_target(kernel.target) + reduction_dtypes = expr.operation.result_dtypes( kernel, arg_dtype, expr.inames) + reduction_dtypes = tuple( + dt.with_target(kernel.target) for dt in reduction_dtypes) ncomp = len(reduction_dtypes) -- GitLab