From c36eb5263283aba4a6564da2dce43a73bc0759e2 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 26 Mar 2018 11:22:34 -0500 Subject: [PATCH] Added the support for a reduction callable. --- loopy/kernel/creation.py | 15 +++-- loopy/kernel/function_interface.py | 26 ++++----- loopy/kernel/reduction_callable.py | 31 ++++------ loopy/library/reduction.py | 90 ++++++++++++++++++++++++------ loopy/preprocess.py | 23 ++++---- loopy/symbolic.py | 34 +++++------ loopy/target/opencl.py | 2 +- loopy/type_inference.py | 54 +++++++++++++----- 8 files changed, 178 insertions(+), 97 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 5a6423220..343c85014 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1898,7 +1898,7 @@ class FunctionScoper(IdentityMapper): from loopy.symbolic import Reduction return Reduction( - ScopedFunction(expr.operation.name), + ScopedFunction(expr.function.name), tuple(new_inames), self.rec(expr.expr), allow_simultaneous=expr.allow_simultaneous) @@ -1918,9 +1918,10 @@ class ScopedFunctionCollector(CombineMapper): def map_reduction(self, expr): from loopy.kernel.reduction_callable import CallableReduction + from loopy.kernel.function_interface import CallableOnScalar from loopy.symbolic import Reduction - callable_reduction = CallableReduction(expr.operation.name) + callable_reduction = CallableReduction(expr.function.name) # sanity checks @@ -1949,8 +1950,14 @@ class ScopedFunctionCollector(CombineMapper): elif isinstance(expr, Reduction) and callable_reduction.is_tuple_typed: raise LoopyError("got a tuple typed argument to a scalar reduction") - return frozenset([(expr.operation.name, - callable_reduction)]) + hidden_function = callable_reduction.operation.hidden_function() + if hidden_function is not None: + return frozenset([(expr.function.name, + callable_reduction), (hidden_function, + CallableOnScalar(hidden_function))]) + else: + return 
frozenset([(expr.function.name, + callable_reduction)]) def map_constant(self, expr): return frozenset() diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index bc5d178b1..fb80c5876 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -134,23 +134,17 @@ class InKernelCallable(ImmutableRecord): """ - fields = set(["name", "arg_id_to_dtype", "arg_id_to_descr"]) - init_arg_names = ("name", "arg_id_to_dtype", "arg_id_to_descr") + fields = set(["arg_id_to_dtype", "arg_id_to_descr"]) + init_arg_names = ("arg_id_to_dtype", "arg_id_to_descr") - def __init__(self, name, arg_id_to_dtype=None, arg_id_to_descr=None): + def __init__(self, arg_id_to_dtype=None, arg_id_to_descr=None): - # sanity checks - - if not isinstance(name, str): - raise LoopyError("name of an InKernelCallable should be a string") - - super(InKernelCallable, self).__init__(name=name, + super(InKernelCallable, self).__init__( arg_id_to_dtype=arg_id_to_dtype, arg_id_to_descr=arg_id_to_descr) def __getinitargs__(self): - return (self.name, self.arg_id_to_dtype, self.arg_id_to_descr, - self.name_in_target) + return (self.arg_id_to_dtype, self.arg_id_to_descr) def with_types(self, arg_id_to_dtype, target): """ @@ -245,10 +239,11 @@ class CallableOnScalar(InKernelCallable): def __init__(self, name, arg_id_to_dtype=None, arg_id_to_descr=None, name_in_target=None): - super(InKernelCallable, self).__init__(name=name, + super(InKernelCallable, self).__init__( arg_id_to_dtype=arg_id_to_dtype, arg_id_to_descr=arg_id_to_descr) + self.name = name self.name_in_target = name_in_target def __getinitargs__(self): @@ -265,7 +260,7 @@ class CallableOnScalar(InKernelCallable): if self.arg_id_to_dtype[id] != arg_id_to_dtype[id]: raise LoopyError("Overwriting a specialized" " function is illegal--maybe start with new instance of" - " CallableScalar?") + " CallableOnScalar?") # {{{ attempt to specialize using scalar functions present in target @@ -406,12 
+401,13 @@ class CallableKernel(InKernelCallable): def __init__(self, name, subkernel, arg_id_to_dtype=None, arg_id_to_descr=None, name_in_target=None): - super(InKernelCallable, self).__init__(name=name, + super(InKernelCallable, self).__init__( arg_id_to_dtype=arg_id_to_dtype, arg_id_to_descr=arg_id_to_descr) if name_in_target is not None: subkernel = subkernel.copy(name=name_in_target) + self.name = name self.name_in_target = name_in_target self.subkernel = subkernel @@ -628,7 +624,7 @@ def register_pymbolic_calls_to_knl_callables(kernel, unique_name = next_indexed_name(unique_name) # book-keeping of the functions and names mappings for later use - if in_knl_callable.subkernel is not None: + if isinstance(in_knl_callable, CallableKernel): # for array calls the name in the target is the name of the # scoped funciton in_knl_callable = in_knl_callable.copy( diff --git a/loopy/kernel/reduction_callable.py b/loopy/kernel/reduction_callable.py index 1682f7160..1ad2acd8d 100644 --- a/loopy/kernel/reduction_callable.py +++ b/loopy/kernel/reduction_callable.py @@ -28,7 +28,7 @@ class CallableReduction(InKernelCallable): self.operation = operation - super(InKernelCallable, self).__init__(name="", + super(InKernelCallable, self).__init__( arg_id_to_dtype=arg_id_to_dtype, arg_id_to_descr=arg_id_to_descr) @@ -47,39 +47,32 @@ class CallableReduction(InKernelCallable): for id, dtype in arg_id_to_dtype.items(): # only checking for the ones which have been provided - if self.arg_id_to_dtype[id] != arg_id_to_dtype[id]: + if id in self.arg_id_to_dtype and ( + self.arg_id_to_dtype[id] != arg_id_to_dtype[id]): raise LoopyError("Overwriting a specialized" " function is illegal--maybe start with new instance of" - " CallableScalar?") - - if self.name in target.get_device_ast_builder().function_identifiers(): - new_in_knl_callable = target.get_device_ast_builder().with_types( - self, arg_id_to_dtype) - if new_in_knl_callable is None: - new_in_knl_callable = self.copy() - return 
new_in_knl_callable - - # did not find a scalar function and function prototype does not - # even have subkernel registered => no match found - raise LoopyError("Function %s not present within" - " the %s namespace" % (self.name, target)) + " CallableReduction?") + updated_arg_id_to_dtype = self.operation.with_types(arg_id_to_dtype, + target) + return self.copy(arg_id_to_dtype=updated_arg_id_to_dtype) def with_descrs(self, arg_id_to_descr): + # not sure what would be the reson of having this over here # This is a scalar call # need to assert that the name is in funtion indentifiers arg_id_to_descr[-1] = ValueArgDescriptor() return self.copy(arg_id_to_descr=arg_id_to_descr) - def with_iname_tag_usage(self, unusable, concurrent_shape): - - raise NotImplementedError() + def inline(self, kernel): + # Replaces the job of realize_reduction + raise NotImplementedError def is_ready_for_code_gen(self): return (self.arg_id_to_dtype is not None and self.arg_id_to_descr is not None and - self.name_in_target is not None) + self.operation is not None) # vim: foldmethod=marker diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index 5daa1528a..f4444c886 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -36,7 +36,7 @@ class ReductionOperation(object): equality-comparable. 
""" - def result_dtypes(self, target, *arg_dtypes): + def with_types(self, arg_id_to_dtype, target): """ :arg arg_dtypes: may be None if not known :returns: None if not known, otherwise the returned type @@ -51,6 +51,9 @@ class ReductionOperation(object): def neutral_element(self, *dtypes): raise NotImplementedError + def hidden_function(self): + return None + def __hash__(self): # Force subclasses to override raise NotImplementedError @@ -95,15 +98,22 @@ class ScalarReductionOperation(ReductionOperation): def arg_count(self): return 1 - def result_dtypes(self, kernel, arg_dtype): + def with_types(self, arg_id_to_dtype, target): + if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None: + # do not have enough info to figure out the type. + return arg_id_to_dtype.copy() + + arg_dtype = arg_id_to_dtype[0] + + updated_arg_id_to_dtype = arg_id_to_dtype.copy() if self.forced_result_type is not None: - return (self.parse_result_type( - kernel.target, self.forced_result_type),) + updated_arg_id_to_dtype[-1] = (self.parse_result_type( + target, self.forced_result_type),) + return updated_arg_id_to_dtype - if arg_dtype is None: - return None + updated_arg_id_to_dtype[-1] = arg_dtype - return (arg_dtype,) + return updated_arg_id_to_dtype def __hash__(self): return hash((type(self), self.forced_result_type)) @@ -180,7 +190,11 @@ class MaxReductionOperation(ScalarReductionOperation): return get_ge_neutral(dtype) def __call__(self, dtype, operand1, operand2): - return var("max")(operand1, operand2) + from loopy.symbolic import ScopedFunction + return ScopedFunction("max")(operand1, operand2) + + def hidden_function(self): + return "max" class MinReductionOperation(ScalarReductionOperation): @@ -188,7 +202,11 @@ class MinReductionOperation(ScalarReductionOperation): return get_le_neutral(dtype) def __call__(self, dtype, operand1, operand2): - return var("min")(operand1, operand2) + from loopy.symbolic import ScopedFunction + return ScopedFunction("min")(operand1, operand2) + + 
def hidden_function(self): + return "min" # {{{ base class for symbolic reduction ops @@ -233,9 +251,22 @@ class _SegmentedScalarReductionOperation(ReductionOperation): return var("make_tuple")(scalar_neutral_element, segment_flag_dtype.numpy_dtype.type(0)) - def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype): - return (self.inner_reduction.result_dtypes(kernel, scalar_dtype) - + (segment_flag_dtype,)) + def with_types(self, arg_id_to_dtype, target): + for id in range(self.arg_count): + if id not in arg_id_to_dtype or arg_id_to_dtype[id] is None: + # types of arguemnts not known => result type cannot be + # determined. + return arg_id_to_dtype.copy() + + scalar_dtype = arg_id_to_dtype[0] + segment_flag_dtype = arg_id_to_dtype[1] + + updated_arg_id_to_dtype = arg_id_to_dtype.copy() + updated_arg_id_to_dtype[-1] = self.inner_reduction.with_types( + {0: scalar_dtype}, target)[-1] + updated_arg_id_to_dtype[-2] = segment_flag_dtype + + return updated_arg_id_to_dtype def __str__(self): return "segmented(%s)" % self.which @@ -299,8 +330,22 @@ class _ArgExtremumReductionOperation(ReductionOperation): scalar_dtype.numpy_dtype.type.__name__, index_dtype.numpy_dtype.type.__name__) - def result_dtypes(self, kernel, scalar_dtype, index_dtype): - return (scalar_dtype, index_dtype) + def with_types(self, arg_id_to_dtype, target): + for id in range(self.arg_count): + if id not in arg_id_to_dtype or arg_id_to_dtype[id] is None: + # types of arguemnts not known => result type cannot be + # determined. 
+ return arg_id_to_dtype.copy() + + scalar_dtype = arg_id_to_dtype[0] + index_dtype = arg_id_to_dtype[1] + + updated_arg_id_to_dtype = arg_id_to_dtype.copy() + + updated_arg_id_to_dtype[-1] = scalar_dtype + updated_arg_id_to_dtype[-2] = index_dtype + + return updated_arg_id_to_dtype def neutral_element(self, scalar_dtype, index_dtype): scalar_neutral_func = ( @@ -331,12 +376,18 @@ class ArgMaxReductionOperation(_ArgExtremumReductionOperation): update_comparison = ">=" neutral_sign = -1 + def hidden_function(self): + return "max" + class ArgMinReductionOperation(_ArgExtremumReductionOperation): which = "min" update_comparison = "<=" neutral_sign = +1 + def hidden_function(self): + return "min" + def get_argext_preamble(kernel, func_id, arg_dtypes): op = func_id.reduction_op @@ -377,8 +428,8 @@ def get_argext_preamble(kernel, func_id, arg_dtypes): _REDUCTION_OPS = { "sum": SumReductionOperation, "product": ProductReductionOperation, - "max": MaxReductionOperation, - "min": MinReductionOperation, + "maximum": MaxReductionOperation, + "minimum": MinReductionOperation, "argmax": ArgMaxReductionOperation, "argmin": ArgMinReductionOperation, "segmented(sum)": SegmentedSumReductionOperation, @@ -429,6 +480,12 @@ def reduction_function_identifiers(): return set(op for op in _REDUCTION_OPS) +def reduction_function_mangler(kernel, func_id, arg_dtypes): + raise NotImplementedError("Reduction Function Mangler!") + + +''' +# KK -- we will replace this with the new interface def reduction_function_mangler(kernel, func_id, arg_dtypes): if isinstance(func_id, ArgExtOp): from loopy.target.opencl import CTarget @@ -475,6 +532,7 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes): ) return None +''' def reduction_preamble_generator(preamble_info): diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 34fe6e830..51389f4f5 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -39,7 +39,6 @@ from loopy.kernel.data import make_assignment # for
the benefit of loopy.statistics, for now from loopy.type_inference import infer_unknown_types from loopy.symbolic import ScopedFunction, CombineMapper -from pymbolic.mapper import Collector from loopy.kernel.instruction import (MultiAssignmentBase, CInstruction, CallInstruction, _DataObliviousInstruction) @@ -893,7 +892,6 @@ def _insert_subdomain_into_domain_tree(kernel, domains, subdomain): # }}} - def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, automagic_scans_ok=False, force_scan=False, force_outer_iname_for_scan=None): @@ -1041,13 +1039,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, init_id = insn_id_gen( "%s_%s_init" % (insn.id, "_".join(expr.inames))) + reduction_operation = kernel.scoped_functions[ + expr.function.name].operation + init_insn = make_assignment( id=init_id, assignees=acc_vars, within_inames=outer_insn_inames - frozenset(expr.inames), within_inames_is_final=insn.within_inames_is_final, depends_on=init_insn_depends_on, - expression=expr.operation.neutral_element(*arg_dtypes), + expression=reduction_operation.neutral_element(*arg_dtypes), predicates=insn.predicates,) generated_insns.append(init_insn) @@ -1082,10 +1083,12 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, else: reduction_expr = expr.expr + reduction_operation = kernel.scoped_functions[ + expr.function.name].operation reduction_insn = make_assignment( id=update_id, assignees=acc_vars, - expression=expr.operation( + expression=reduction_operation( arg_dtypes, _strip_if_scalar(acc_vars, acc_vars), reduction_expr), @@ -1094,8 +1097,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, within_inames_is_final=insn.within_inames_is_final, predicates=insn.predicates,) - reduction_insn = scope_function_in_insn(reduction_insn, kenrel) - generated_insns.append(reduction_insn) new_insn_add_depends_on.add(reduction_insn.id) @@ -1944,6 +1945,8 @@ def realize_reduction(kernel, 
insn_id_filter=None, unknown_types_ok=True, kernel = lp.tag_inames(kernel, new_iname_tags) + # making changes to the scoped function that are arising + # TODO: remove unused inames... kernel = ( @@ -2381,10 +2384,6 @@ def preprocess_kernel(kernel, device=None): from loopy.kernel.creation import apply_single_writer_depencency_heuristic kernel = apply_single_writer_depencency_heuristic(kernel) - # inferring the shape and dim_tags of the arguments involved in a function - # call. - kernel = infer_arg_descr(kernel) - # Ordering restrictions: # # - realize_reduction must happen after type inference because it needs @@ -2396,6 +2395,10 @@ def preprocess_kernel(kernel, device=None): kernel = realize_reduction(kernel, unknown_types_ok=False) + # inferring the shape and dim_tags of the arguments involved in a function + # call. + kernel = infer_arg_descr(kernel) + # Ordering restriction: # add_axes_to_temporaries_for_ilp because reduction accumulators # need to be duplicated by this. diff --git a/loopy/symbolic.py b/loopy/symbolic.py index e8e39a24f..32670c1cc 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -96,7 +96,7 @@ class IdentityMapperMixin(object): new_inames.append(new_sym_iname.name) return Reduction( - expr.operation, tuple(new_inames), + expr.function, tuple(new_inames), self.rec(expr.expr, *args), allow_simultaneous=expr.allow_simultaneous) @@ -226,7 +226,7 @@ class StringifyMapper(StringifyMapperBase): return "%sreduce(%s, [%s], %s)" % ( "simul_" if expr.allow_simultaneous else "", - expr.operation, ", ".join(expr.inames), + expr.function, ", ".join(expr.inames), self.rec(expr.expr, PREC_NONE)) def map_tagged_variable(self, expr, prec): @@ -266,7 +266,7 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): if not isinstance(other, type(expr)): return self.treat_mismatch(expr, other, unis) if (expr.inames != other.inames - or type(expr.operation) != type(other.operation) # noqa + or type(expr.function) != type(other.function) # noqa ): return [] 
@@ -537,7 +537,7 @@ class Reduction(p.Expression): """Represents a reduction operation on :attr:`exprs` across :attr:`inames`. - ..attribute:: operation + ..attribute:: function an instance of :class:`pymbolic.primitives.Variable` which indicates the reduction callable that the reduction would point to in the dict @@ -562,10 +562,10 @@ class Reduction(p.Expression): in precisely one reduction, to avoid mis-nesting errors. """ - init_arg_names = ("operation", "inames", "expr", "allow_simultaneous") + init_arg_names = ("function", "inames", "expr", "allow_simultaneous") - def __init__(self, operation, inames, expr, allow_simultaneous=False): - assert isinstance(operation, p.Variable) + def __init__(self, function, inames, expr, allow_simultaneous=False): + assert isinstance(function, p.Variable) if isinstance(inames, str): inames = tuple(iname.strip() for iname in inames.split(",")) @@ -610,20 +610,20 @@ class Reduction(p.Expression): raise LoopyError("got a tuple typed argument to a scalar reduction") """ - self.operation = operation + self.function = function self.inames = inames self.expr = expr self.allow_simultaneous = allow_simultaneous def __getinitargs__(self): - return (self.operation, self.inames, self.expr, self.allow_simultaneous) + return (self.function, self.inames, self.expr, self.allow_simultaneous) def get_hash(self): - return hash((self.__class__, self.operation, self.inames, self.expr)) + return hash((self.__class__, self.function, self.inames, self.expr)) def is_equal(self, other): return (other.__class__ == self.__class__ - and other.operation == self.operation + and other.function == self.function and other.inames == self.inames and other.expr == self.expr) @@ -1146,10 +1146,10 @@ class FunctionToPrimitiveMapper(IdentityMapper): turns those into the actual pymbolic primitives used for that.
""" - def _parse_reduction(self, operation, inames, red_exprs, + def _parse_reduction(self, function, inames, red_exprs, allow_simultaneous=False): - assert isinstance(operation, str) - operation = p.Variable(operation) + assert isinstance(function, str) + function = p.Variable(function) if isinstance(inames, p.Variable): inames = (inames,) @@ -1168,7 +1168,7 @@ class FunctionToPrimitiveMapper(IdentityMapper): if len(red_exprs) == 1: red_exprs = red_exprs[0] - return Reduction(operation, tuple(processed_inames), red_exprs, + return Reduction(function, tuple(processed_inames), red_exprs, allow_simultaneous=allow_simultaneous) def map_call(self, expr): @@ -1194,10 +1194,10 @@ class FunctionToPrimitiveMapper(IdentityMapper): elif name in set(["reduce, simul_reduce"]): if len(expr.parameters) >= 3: - operation, inames = expr.parameters[:2] + function, inames = expr.parameters[:2] red_exprs = expr.parameters[2:] - return self._parse_reduction(str(operation), inames, + return self._parse_reduction(str(function), inames, tuple(self.rec(red_expr) for red_expr in red_exprs), allow_simultaneous=(name == "simul_reduce")) else: diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 7aec34a22..7ffd91309 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -255,7 +255,7 @@ def opencl_with_types(in_knl_callable, arg_id_to_dtype): dtype = np.find_common_type( [], [dtype.numpy_dtype for id, dtype in - arg_id_to_dtype.values() if id >= 0]) + arg_id_to_dtype.items() if id >= 0]) if dtype.kind == "i": dtype = NumpyType(dtype) diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 11113538e..8df9773a9 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -396,7 +396,10 @@ class TypeInferenceMapper(CombineMapper): from loopy.symbolic import Reduction from pymbolic.primitives import Call - if not return_tuple and expr.is_tuple_typed: + reduction_callable = self.scoped_functions[ + expr.function.name] + + if not return_tuple and 
reduction_callable.is_tuple_typed: raise LoopyError("reductions with more or fewer than one " "return value may only be used in direct " "assignments") @@ -416,12 +419,23 @@ class TypeInferenceMapper(CombineMapper): else: rec_results = self.rec(expr.expr) - if return_tuple: - return [expr.operation.result_dtypes(self.kernel, *rec_result) - for rec_result in rec_results] - else: - return [expr.operation.result_dtypes(self.kernel, rec_result)[0] - for rec_result in rec_results] + arg_id_to_dtype = dict(enumerate(rec_results)) + + in_knl_callable = ( + self.scoped_functions[expr.function.name].with_types( + arg_id_to_dtype, self.kernel.target)) + + # storing the type specialized function so that it can be used for + # later use + self.specialized_functions[expr] = in_knl_callable + + new_arg_id_to_dtype = in_knl_callable.arg_id_to_dtype + + # collecting result dtypes in order of the assignees + if -1 in new_arg_id_to_dtype and new_arg_id_to_dtype[-1] is not None: + return [new_arg_id_to_dtype[-1]] + + return [] def map_sub_array_ref(self, expr): return self.rec(expr.get_begin_subscript()) @@ -691,8 +705,9 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression( kernel, expr, unknown_types_ok): type_inf_mapper = TypeInferenceMapper(kernel) import loopy as lp + callable_reduction = kernel.scoped_functions[expr.function.name] - if expr.is_tuple_typed: + if callable_reduction.is_tuple_typed: arg_dtypes_result = type_inf_mapper( expr, return_tuple=True, return_dtype_set=True) @@ -700,7 +715,7 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression( arg_dtypes = arg_dtypes_result[0] else: if unknown_types_ok: - arg_dtypes = [lp.auto] * expr.operation.arg_count + arg_dtypes = [lp.auto] * callable_reduction.operation.arg_count else: raise LoopyError("failed to determine types of accumulators for " "reduction '%s'" % expr) @@ -714,13 +729,22 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression( raise LoopyError("failed to determine type of accumulator 
for " "reduction '%s'" % expr) - reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes) - reduction_dtypes = tuple( - dt.with_target(kernel.target) - if dt is not lp.auto else dt - for dt in reduction_dtypes) + # TODODODODODODODODODO + + new_arg_id_to_dtype = callable_reduction.with_types( + dict(enumerate(arg_dtypes)), kernel.target).arg_id_to_dtype + + num_result = len([id for id in new_arg_id_to_dtype if id < 0]) + reduction_dtypes = [] + + for id in range(num_result): + dt = new_arg_id_to_dtype[-id-1] + if dt is not lp.auto: + reduction_dtypes.append(dt.with_target(kernel.target)) + else: + reduction_dtypes.append(dt) - return tuple(arg_dtypes), reduction_dtypes + return tuple(arg_dtypes), tuple(reduction_dtypes) # }}} -- GitLab