From 2c77711ad436ca29000f0c9791948787bb6f4b34 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 11 May 2016 21:56:28 -0500
Subject: [PATCH] Fixes to multivalued functions

---
 loopy/library/reduction.py | 24 +++++++++++++++++++-----
 loopy/preprocess.py        |  4 ++++
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index 1540222b2..8a38eebd5 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -84,8 +84,8 @@ class ScalarReductionOperation(ReductionOperation):
 
     def result_dtypes(self, kernel, arg_dtype, inames):
         if self.forced_result_type is not None:
-            return self.parse_result_type(
-                    kernel.target, self.forced_result_type)
+            return (self.parse_result_type(
+                    kernel.target, self.forced_result_type),)
 
         return (arg_dtype,)
 
@@ -289,7 +289,7 @@ def parse_reduction_op(name):
 
 
 def reduction_function_mangler(kernel, func_id, arg_dtypes):
-    if isinstance(func_id, ArgExtFunction):
+    if isinstance(func_id, ArgExtFunction) and func_id.name == "init":
         from loopy.target.opencl import OpenCLTarget
         if not isinstance(kernel.target, OpenCLTarget):
             raise LoopyError("only OpenCL supported for now")
@@ -298,8 +298,22 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes):
 
         from loopy.kernel.data import CallMangleInfo
         return CallMangleInfo(
-                target_name="%s_%s" % (
-                    op.prefix(func_id.scalar_dtype), func_id.name),
+                target_name="%s_init" % op.prefix(func_id.scalar_dtype),
+                result_dtypes=op.result_dtypes(
+                    kernel, func_id.scalar_dtype, func_id.inames),
+                arg_dtypes=(),
+                )
+
+    elif isinstance(func_id, ArgExtFunction) and func_id.name == "update":
+        from loopy.target.opencl import OpenCLTarget
+        if not isinstance(kernel.target, OpenCLTarget):
+            raise LoopyError("only OpenCL supported for now")
+
+        op = func_id.reduction_op
+
+        from loopy.kernel.data import CallMangleInfo
+        return CallMangleInfo(
+                target_name="%s_update" % op.prefix(func_id.scalar_dtype),
                 result_dtypes=op.result_dtypes(
                     kernel, func_id.scalar_dtype, func_id.inames),
                 arg_dtypes=(
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 51d588ef5..e1ee119d5 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -469,8 +469,12 @@ def realize_reduction(kernel, insn_id_filter=None):
             raise LoopyError("failed to determine type of accumulator for "
                     "reduction '%s'" % expr)
 
+        arg_dtype = arg_dtype.with_target(kernel.target)
+
         reduction_dtypes = expr.operation.result_dtypes(
                     kernel, arg_dtype, expr.inames)
+        reduction_dtypes = tuple(
+                dt.with_target(kernel.target) for dt in reduction_dtypes)
 
         ncomp = len(reduction_dtypes)
 
-- 
GitLab