Skip to content
Snippets Groups Projects
Commit 2c77711a authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fixes to multivalued functions

parent 27281581
No related branches found
No related tags found
No related merge requests found
......@@ -84,8 +84,8 @@ class ScalarReductionOperation(ReductionOperation):
def result_dtypes(self, kernel, arg_dtype, inames):
if self.forced_result_type is not None:
return self.parse_result_type(
kernel.target, self.forced_result_type)
return (self.parse_result_type(
kernel.target, self.forced_result_type),)
return (arg_dtype,)
......@@ -289,7 +289,7 @@ def parse_reduction_op(name):
def reduction_function_mangler(kernel, func_id, arg_dtypes):
if isinstance(func_id, ArgExtFunction):
if isinstance(func_id, ArgExtFunction) and func_id.name == "init":
from loopy.target.opencl import OpenCLTarget
if not isinstance(kernel.target, OpenCLTarget):
raise LoopyError("only OpenCL supported for now")
......@@ -298,8 +298,22 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes):
from loopy.kernel.data import CallMangleInfo
return CallMangleInfo(
target_name="%s_%s" % (
op.prefix(func_id.scalar_dtype), func_id.name),
target_name="%s_init" % op.prefix(func_id.scalar_dtype),
result_dtypes=op.result_dtypes(
kernel, func_id.scalar_dtype, func_id.inames),
arg_dtypes=(),
)
elif isinstance(func_id, ArgExtFunction) and func_id.name == "update":
from loopy.target.opencl import OpenCLTarget
if not isinstance(kernel.target, OpenCLTarget):
raise LoopyError("only OpenCL supported for now")
op = func_id.reduction_op
from loopy.kernel.data import CallMangleInfo
return CallMangleInfo(
target_name="%s_update" % op.prefix(func_id.scalar_dtype),
result_dtypes=op.result_dtypes(
kernel, func_id.scalar_dtype, func_id.inames),
arg_dtypes=(
......
......@@ -469,8 +469,12 @@ def realize_reduction(kernel, insn_id_filter=None):
raise LoopyError("failed to determine type of accumulator for "
"reduction '%s'" % expr)
arg_dtype = arg_dtype.with_target(kernel.target)
reduction_dtypes = expr.operation.result_dtypes(
kernel, arg_dtype, expr.inames)
reduction_dtypes = tuple(
dt.with_target(kernel.target) for dt in reduction_dtypes)
ncomp = len(reduction_dtypes)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment