From 969ce8ddb707ed778f535a72a96ee7b48772f95c Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Fri, 7 Apr 2017 15:39:28 -0500 Subject: [PATCH] Add support for tuple-typed reductions. * Eliminates inames arguments to reduction functions (closes #32). * Changes the argmax function to accept two arguments. See also: #62 --- loopy/kernel/instruction.py | 6 +- loopy/library/reduction.py | 246 +++++++++++++++++++++++++++++++----- loopy/preprocess.py | 92 +++++++------- loopy/symbolic.py | 110 +++++++++++----- loopy/transform/data.py | 22 +++- loopy/transform/iname.py | 29 ++++- loopy/type_inference.py | 62 +++++++-- test/test_reduction.py | 18 ++- 8 files changed, 443 insertions(+), 142 deletions(-) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index fdd8f1d37..752e3e4da 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -658,7 +658,11 @@ class MultiAssignmentBase(InstructionBase): @memoize_method def reduction_inames(self): def map_reduction(expr, rec): - rec(expr.expr) + if expr.is_plain_tuple: + for sub_expr in expr.exprs: + rec(sub_expr) + else: + rec(expr.exprs) for iname in expr.inames: result.add(iname) diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index d24b61c12..e3c7e6099 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -36,15 +36,19 @@ class ReductionOperation(object): equality-comparable. 
""" - def result_dtypes(self, target, arg_dtype, inames): + def result_dtypes(self, target, *arg_dtypes): """ - :arg arg_dtype: may be None if not known + :arg arg_dtypes: may be None if not known :returns: None if not known, otherwise the returned type """ raise NotImplementedError - def neutral_element(self, dtype, inames): + @property + def arg_count(self): + raise NotImplementedError + + def neutral_element(self, *dtypes): raise NotImplementedError def __hash__(self): @@ -55,12 +59,16 @@ class ReductionOperation(object): # Force subclasses to override raise NotImplementedError - def __call__(self, dtype, operand1, operand2, inames): + def __call__(self, dtype, operand1, operand2): raise NotImplementedError def __ne__(self, other): return not self.__eq__(other) + @property + def is_segmented(self): + raise NotImplementedError + @staticmethod def parse_result_type(target, op_type): try: @@ -87,7 +95,11 @@ class ScalarReductionOperation(ReductionOperation): """ self.forced_result_type = forced_result_type - def result_dtypes(self, kernel, arg_dtype, inames): + @property + def arg_count(self): + return 1 + + def result_dtypes(self, kernel, arg_dtype): if self.forced_result_type is not None: return (self.parse_result_type( kernel.target, self.forced_result_type),) @@ -114,18 +126,18 @@ class ScalarReductionOperation(ReductionOperation): class SumReductionOperation(ScalarReductionOperation): - def neutral_element(self, dtype, inames): + def neutral_element(self, dtype): return 0 - def __call__(self, dtype, operand1, operand2, inames): + def __call__(self, dtype, operand1, operand2): return operand1 + operand2 class ProductReductionOperation(ScalarReductionOperation): - def neutral_element(self, dtype, inames): + def neutral_element(self, dtype): return 1 - def __call__(self, dtype, operand1, operand2, inames): + def __call__(self, dtype, operand1, operand2): return operand1 * operand2 @@ -166,32 +178,144 @@ def get_ge_neutral(dtype): class 
MaxReductionOperation(ScalarReductionOperation): - def neutral_element(self, dtype, inames): + def neutral_element(self, dtype): return get_ge_neutral(dtype) - def __call__(self, dtype, operand1, operand2, inames): + def __call__(self, dtype, operand1, operand2): return var("max")(operand1, operand2) class MinReductionOperation(ScalarReductionOperation): - def neutral_element(self, dtype, inames): + def neutral_element(self, dtype): return get_le_neutral(dtype) - def __call__(self, dtype, operand1, operand2, inames): + def __call__(self, dtype, operand1, operand2): return var("min")(operand1, operand2) +# {{{ segmented reduction + +class _SegmentedScalarReductionOperation(ReductionOperation): + def __init__(self, **kwargs): + self.inner_reduction = self.base_reduction_class(**kwargs) + + @property + def arg_count(self): + return 2 + + def prefix(self, scalar_dtype, segment_flag_dtype): + return "loopy_segmented_%s_%s_%s" % (self.which, + scalar_dtype.numpy_dtype.type.__name__, + segment_flag_dtype.numpy_dtype.type.__name__) + + def neutral_element(self, scalar_dtype, segment_flag_dtype): + return SegmentedFunction(self, (scalar_dtype, segment_flag_dtype), "init")() + + def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype): + return (self.inner_reduction.result_dtypes(kernel, scalar_dtype) + + (segment_flag_dtype,)) + + def __str__(self): + return "segmented(%s)" % self.which + + def __hash__(self): + return hash(type(self)) + + def __eq__(self, other): + return type(self) == type(other) + + def __call__(self, dtypes, operand1, operand2): + return SegmentedFunction(self, dtypes, "update")(*(operand1 + operand2)) + + +class SegmentedSumReductionOperation(_SegmentedScalarReductionOperation): + base_reduction_class = SumReductionOperation + which = "sum" + op = "((%s) + (%s))" + + +class SegmentedProductReductionOperation(_SegmentedScalarReductionOperation): + base_reduction_class = ProductReductionOperation + op = "((%s) * (%s))" + which = "product" + + 
+class SegmentedFunction(FunctionIdentifier): + init_arg_names = ("reduction_op", "dtypes", "name") + + def __init__(self, reduction_op, dtypes, name): + """ + :arg dtypes: A :class:`tuple` of `(scalar_dtype, segment_flag_dtype)` + """ + self.reduction_op = reduction_op + self.dtypes = dtypes + self.name = name + + @property + def scalar_dtype(self): + return self.dtypes[0] + + @property + def segment_flag_dtype(self): + return self.dtypes[1] + + def __getinitargs__(self): + return (self.reduction_op, self.dtypes, self.name) + + +def get_segmented_function_preamble(kernel, func_id): + op = func_id.reduction_op + prefix = op.prefix(func_id.scalar_dtype, func_id.segment_flag_dtype) + + from pymbolic.mapper.c_code import CCodeMapper + + c_code_mapper = CCodeMapper() + + return (prefix, """ + inline %(scalar_t)s %(prefix)s_init(%(segment_flag_t)s *segment_flag_out) + { + *segment_flag_out = 0; + return %(neutral)s; + } + + inline %(scalar_t)s %(prefix)s_update( + %(scalar_t)s op1, %(segment_flag_t)s segment_flag1, + %(scalar_t)s op2, %(segment_flag_t)s segment_flag2, + %(segment_flag_t)s *segment_flag_out) + { + *segment_flag_out = segment_flag1 | segment_flag2; + return segment_flag2 ? 
op2 : %(combined)s; + } + """ % dict( + scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype), + prefix=prefix, + segment_flag_t=kernel.target.dtype_to_typename( + func_id.segment_flag_dtype), + neutral=c_code_mapper( + op.inner_reduction.neutral_element(func_id.scalar_dtype)), + combined=op.op % ("op1", "op2"), + )) + + +# }}} + + # {{{ argmin/argmax class _ArgExtremumReductionOperation(ReductionOperation): - def prefix(self, dtype): - return "loopy_arg%s_%s" % (self.which, dtype.numpy_dtype.type.__name__) + def prefix(self, scalar_dtype, index_dtype): + return "loopy_arg%s_%s_%s" % (self.which, + index_dtype.numpy_dtype.type.__name__, + scalar_dtype.numpy_dtype.type.__name__) + + def result_dtypes(self, kernel, scalar_dtype, index_dtype): + return (scalar_dtype, index_dtype) - def result_dtypes(self, kernel, dtype, inames): - return (dtype, kernel.index_dtype) + def neutral_element(self, scalar_dtype, index_dtype): + return ArgExtFunction(self, (scalar_dtype, index_dtype), "init")() - def neutral_element(self, dtype, inames): - return ArgExtFunction(self, dtype, "init", inames)() + def __str__(self): + return self.which def __hash__(self): return hash(type(self)) @@ -199,11 +323,12 @@ class _ArgExtremumReductionOperation(ReductionOperation): def __eq__(self, other): return type(self) == type(other) - def __call__(self, dtype, operand1, operand2, inames): - iname, = inames + @property + def arg_count(self): + return 2 - return ArgExtFunction(self, dtype, "update", inames)( - *(operand1 + (operand2, var(iname)))) + def __call__(self, dtypes, operand1, operand2): + return ArgExtFunction(self, dtypes, "update")(*(operand1 + operand2)) class ArgMaxReductionOperation(_ArgExtremumReductionOperation): @@ -219,21 +344,28 @@ class ArgMinReductionOperation(_ArgExtremumReductionOperation): class ArgExtFunction(FunctionIdentifier): - init_arg_names = ("reduction_op", "scalar_dtype", "name", "inames") + init_arg_names = ("reduction_op", "dtypes", "name") - def 
__init__(self, reduction_op, scalar_dtype, name, inames): + def __init__(self, reduction_op, dtypes, name): self.reduction_op = reduction_op - self.scalar_dtype = scalar_dtype + self.dtypes = dtypes self.name = name - self.inames = inames + + @property + def scalar_dtype(self): + return self.dtypes[0] + + @property + def index_dtype(self): + return self.dtypes[1] def __getinitargs__(self): - return (self.reduction_op, self.scalar_dtype, self.name, self.inames) + return (self.reduction_op, self.dtypes, self.name) def get_argext_preamble(kernel, func_id): op = func_id.reduction_op - prefix = op.prefix(func_id.scalar_dtype) + prefix = op.prefix(func_id.scalar_dtype, func_id.index_dtype) from pymbolic.mapper.c_code import CCodeMapper @@ -267,7 +399,7 @@ def get_argext_preamble(kernel, func_id): """ % dict( scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype), prefix=prefix, - index_t=kernel.target.dtype_to_typename(kernel.index_dtype), + index_t=kernel.target.dtype_to_typename(func_id.index_dtype), neutral=c_code_mapper(neutral(func_id.scalar_dtype)), comp=op.update_comparison, )) @@ -284,6 +416,8 @@ _REDUCTION_OPS = { "min": MinReductionOperation, "argmax": ArgMaxReductionOperation, "argmin": ArgMinReductionOperation, + "segmented(sum)": SegmentedSumReductionOperation, + "segmented(product)": SegmentedProductReductionOperation, } _REDUCTION_OP_PARSERS = [ @@ -333,9 +467,10 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes): from loopy.kernel.data import CallMangleInfo return CallMangleInfo( - target_name="%s_init" % op.prefix(func_id.scalar_dtype), + target_name="%s_init" % op.prefix( + func_id.scalar_dtype, func_id.index_dtype), result_dtypes=op.result_dtypes( - kernel, func_id.scalar_dtype, func_id.inames), + kernel, func_id.scalar_dtype, func_id.index_dtype), arg_dtypes=(), ) @@ -348,9 +483,10 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes): from loopy.kernel.data import CallMangleInfo return CallMangleInfo( - 
target_name="%s_update" % op.prefix(func_id.scalar_dtype), + target_name="%s_update" % op.prefix( + func_id.scalar_dtype, func_id.index_dtype), result_dtypes=op.result_dtypes( - kernel, func_id.scalar_dtype, func_id.inames), + kernel, func_id.scalar_dtype, func_id.index_dtype), arg_dtypes=( func_id.scalar_dtype, kernel.index_dtype, @@ -358,6 +494,42 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes): kernel.index_dtype), ) + elif isinstance(func_id, SegmentedFunction) and func_id.name == "init": + from loopy.target.opencl import OpenCLTarget + if not isinstance(kernel.target, OpenCLTarget): + raise LoopyError("only OpenCL supported for now") + + op = func_id.reduction_op + + from loopy.kernel.data import CallMangleInfo + return CallMangleInfo( + target_name="%s_init" % op.prefix( + func_id.scalar_dtype, func_id.segment_flag_dtype), + result_dtypes=op.result_dtypes( + kernel, func_id.scalar_dtype, func_id.segment_flag_dtype), + arg_dtypes=(), + ) + + elif isinstance(func_id, SegmentedFunction) and func_id.name == "update": + from loopy.target.opencl import OpenCLTarget + if not isinstance(kernel.target, OpenCLTarget): + raise LoopyError("only OpenCL supported for now") + + op = func_id.reduction_op + + from loopy.kernel.data import CallMangleInfo + return CallMangleInfo( + target_name="%s_update" % op.prefix( + func_id.scalar_dtype, func_id.segment_flag_dtype), + result_dtypes=op.result_dtypes( + kernel, func_id.scalar_dtype, func_id.segment_flag_dtype), + arg_dtypes=( + func_id.scalar_dtype, + func_id.segment_flag_dtype, + func_id.scalar_dtype, + func_id.segment_flag_dtype), + ) + return None @@ -371,4 +543,10 @@ def reduction_preamble_generator(preamble_info): yield get_argext_preamble(preamble_info.kernel, func.name) + elif isinstance(func.name, SegmentedFunction): + if not isinstance(preamble_info.kernel.target, OpenCLTarget): + raise LoopyError("only OpenCL supported for now") + + yield get_segmented_function_preamble(preamble_info.kernel, 
func.name) + # vim: fdm=marker diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 2b6d97c38..a5c9b0e4f 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -26,7 +26,7 @@ THE SOFTWARE. import six from loopy.diagnostic import ( LoopyError, WriteRaceConditionWarning, warn_with_kernel, - LoopyAdvisory, DependencyTypeInferenceFailure) + LoopyAdvisory) import islpy as isl @@ -97,7 +97,12 @@ def check_reduction_iname_uniqueness(kernel): iname_to_nonsimultaneous_reduction_count = {} def map_reduction(expr, rec): - rec(expr.expr) + if expr.is_plain_tuple: + for sub_expr in expr.exprs: + rec(sub_expr) + else: + rec(expr.exprs) + for iname in expr.inames: iname_to_reduction_count[iname] = ( iname_to_reduction_count.get(iname, 0) + 1) @@ -295,12 +300,19 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): var_name_gen = kernel.get_var_name_generator() new_temporary_variables = kernel.temporary_variables.copy() - from loopy.type_inference import TypeInferenceMapper - type_inf_mapper = TypeInferenceMapper(kernel) + # {{{ helpers + + def _strip_if_scalar(reference, val): + if len(reference) == 1: + return val[0] + else: + return val + + # }}} # {{{ sequential - def map_reduction_seq(expr, rec, nresults, arg_dtype, + def map_reduction_seq(expr, rec, nresults, arg_dtypes, reduction_dtypes): outer_insn_inames = temp_kernel.insn_inames(insn) @@ -328,7 +340,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): within_inames=outer_insn_inames - frozenset(expr.inames), within_inames_is_final=insn.within_inames_is_final, depends_on=frozenset(), - expression=expr.operation.neutral_element(arg_dtype, expr.inames)) + expression=expr.operation.neutral_element(*arg_dtypes)) generated_insns.append(init_insn) @@ -343,9 +355,9 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): id=update_id, assignees=acc_vars, expression=expr.operation( - arg_dtype, - acc_vars if len(acc_vars) > 1 else acc_vars[0], 
- expr.expr, expr.inames), + arg_dtypes, + _strip_if_scalar(acc_vars, acc_vars), + _strip_if_scalar(acc_vars, expr.exprs)), depends_on=frozenset([init_insn.id]) | insn.depends_on, within_inames=update_insn_iname_deps, within_inames_is_final=insn.within_inames_is_final) @@ -382,7 +394,15 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): v[iname].lt_set(v[0] + size)).get_basic_sets() return bs - def map_reduction_local(expr, rec, nresults, arg_dtype, + def _make_slab_set_from_range(iname, lbound, ubound): + v = isl.make_zero_and_vars([iname]) + bs, = ( + v[iname].ge_set(v[0] + lbound) + & + v[iname].lt_set(v[0] + ubound)).get_basic_sets() + return bs + + def map_reduction_local(expr, rec, nresults, arg_dtypes, reduction_dtypes): red_iname, = expr.inames @@ -441,7 +461,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): base_iname_deps = outer_insn_inames - frozenset(expr.inames) - neutral = expr.operation.neutral_element(arg_dtype, expr.inames) + neutral = expr.operation.neutral_element(*arg_dtypes) init_id = insn_id_gen("%s_%s_init" % (insn.id, red_iname)) init_insn = make_assignment( @@ -455,12 +475,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): depends_on=frozenset()) generated_insns.append(init_insn) - def _strip_if_scalar(c): - if len(acc_vars) == 1: - return c[0] - else: - return c - init_neutral_id = insn_id_gen("%s_%s_init_neutral" % (insn.id, red_iname)) init_neutral_insn = make_assignment( id=init_neutral_id, @@ -478,9 +492,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): acc_var[outer_local_iname_vars + (var(red_iname),)] for acc_var in acc_vars), expression=expr.operation( - arg_dtype, - _strip_if_scalar(tuple(var(nvn) for nvn in neutral_var_names)), - expr.expr, expr.inames), + arg_dtypes, + _strip_if_scalar( + expr.exprs, + tuple(var(nvn) for nvn in neutral_var_names)), + _strip_if_scalar(expr.exprs, expr.exprs)), within_inames=( 
(outer_insn_inames - frozenset(expr.inames)) | frozenset([red_iname])), @@ -513,17 +529,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): acc_var[outer_local_iname_vars + (var(stage_exec_iname),)] for acc_var in acc_vars), expression=expr.operation( - arg_dtype, - _strip_if_scalar(tuple( + arg_dtypes, + _strip_if_scalar(acc_vars, tuple( acc_var[ outer_local_iname_vars + (var(stage_exec_iname),)] for acc_var in acc_vars)), - _strip_if_scalar(tuple( + _strip_if_scalar(acc_vars, tuple( acc_var[ outer_local_iname_vars + ( var(stage_exec_iname) + new_size,)] - for acc_var in acc_vars)), - expr.inames), + for acc_var in acc_vars))), within_inames=( base_iname_deps | frozenset([stage_exec_iname])), within_inames_is_final=insn.within_inames_is_final, @@ -554,24 +569,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): # Only expand one level of reduction at a time, going from outermost to # innermost. Otherwise we get the (iname + insn) dependencies wrong. 
- try: - arg_dtype = type_inf_mapper(expr.expr) - except DependencyTypeInferenceFailure: - if unknown_types_ok: - arg_dtype = lp.auto - - reduction_dtypes = (lp.auto,)*nresults - - else: - raise LoopyError("failed to determine type of accumulator for " - "reduction '%s'" % expr) - else: - arg_dtype = arg_dtype.with_target(kernel.target) - - reduction_dtypes = expr.operation.result_dtypes( - kernel, arg_dtype, expr.inames) - reduction_dtypes = tuple( - dt.with_target(kernel.target) for dt in reduction_dtypes) + from loopy.type_inference import ( + infer_arg_and_reduction_dtypes_for_reduction_expression) + arg_dtypes, reduction_dtypes = ( + infer_arg_and_reduction_dtypes_for_reduction_expression( + temp_kernel, expr, unknown_types_ok)) outer_insn_inames = temp_kernel.insn_inames(insn) bad_inames = frozenset(expr.inames) & outer_insn_inames @@ -621,10 +623,10 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): if n_sequential: assert n_local_par == 0 - return map_reduction_seq(expr, rec, nresults, arg_dtype, + return map_reduction_seq(expr, rec, nresults, arg_dtypes, reduction_dtypes) elif n_local_par: - return map_reduction_local(expr, rec, nresults, arg_dtype, + return map_reduction_local(expr, rec, nresults, arg_dtypes, reduction_dtypes) else: from loopy.diagnostic import warn_with_kernel diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 50c891be4..8876e2950 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -95,7 +95,10 @@ class IdentityMapperMixin(object): new_inames.append(new_sym_iname.name) return Reduction( - expr.operation, tuple(new_inames), self.rec(expr.expr, *args), + expr.operation, tuple(new_inames), + (tuple(self.rec(e, *args) for e in expr.exprs) + if expr.is_plain_tuple + else self.rec(expr.exprs, *args)), allow_simultaneous=expr.allow_simultaneous) def map_tagged_variable(self, expr, *args): @@ -144,7 +147,11 @@ class WalkMapper(WalkMapperBase): if not self.visit(expr): return - self.rec(expr.expr, *args) + 
if expr.is_plain_tuple: + for sub_expr in expr.exprs: + self.rec(sub_expr, *args) + else: + self.rec(expr.exprs, *args) map_tagged_variable = WalkMapperBase.map_variable @@ -162,7 +169,10 @@ class CallbackMapper(CallbackMapperBase, IdentityMapper): class CombineMapper(CombineMapperBase): def map_reduction(self, expr): - return self.rec(expr.expr) + if expr.is_plain_tuple: + return self.combine(self.rec(sub_expr) for sub_expr in expr.exprs) + else: + return self.rec(expr.exprs) map_linear_subscript = CombineMapperBase.map_subscript @@ -192,9 +202,13 @@ class StringifyMapper(StringifyMapperBase): return "loc.%d" % expr.index def map_reduction(self, expr, prec): + from pymbolic.mapper.stringifier import PREC_NONE return "%sreduce(%s, [%s], %s)" % ( "simul_" if expr.allow_simultaneous else "", - expr.operation, ", ".join(expr.inames), expr.expr) + expr.operation, ", ".join(expr.inames), + (", ".join(self.rec(e, PREC_NONE) for e in expr.exprs) + if expr.is_plain_tuple + else self.rec(expr.exprs, PREC_NONE))) def map_tagged_variable(self, expr, prec): return "%s$%s" % (expr.name, expr.tag) @@ -224,8 +238,17 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): or type(expr.operation) != type(other.operation) # noqa ): return [] + if expr.is_plain_tuple != other.is_plain_tuple: + return [] + + if expr.is_plain_tuple: + for sub_expr_l, sub_expr_r in zip(expr.exprs, other.exprs): + unis = self.rec(sub_expr_l, sub_expr_r, unis) + if not unis: + break + return unis - return self.rec(expr.expr, other.expr, unis) + return self.rec(expr.exprs, other.exprs, unis) def map_tagged_variable(self, expr, other, urecs): new_uni_record = self.unification_record_from_equation( @@ -258,8 +281,11 @@ class DependencyMapper(DependencyMapperBase): self.rec(child, *args) for child in expr.parameters) def map_reduction(self, expr): - return (self.rec(expr.expr) - - set(p.Variable(iname) for iname in expr.inames)) + if expr.is_plain_tuple: + deps = self.combine(self.rec(sub_expr) for 
sub_expr in expr.exprs) + else: + deps = self.rec(expr.exprs) + return deps - set(p.Variable(iname) for iname in expr.inames) def map_tagged_variable(self, expr): return set([expr]) @@ -428,7 +454,7 @@ class TaggedVariable(p.Variable): class Reduction(p.Expression): - """Represents a reduction operation on :attr:`expr` + """Represents a reduction operation on :attr:`exprs` across :attr:`inames`. .. attribute:: operation @@ -440,10 +466,12 @@ class Reduction(p.Expression): a list of inames across which reduction on :attr:`expr` is being carried out. - .. attribute:: expr + .. attribute:: exprs - The expression (as a :class:`pymbolic.primitives.Expression`) - on which reduction is performed. + A (tuple-typed) expression which currently may be one of + * a :class:`tuple` of :class:`pymbolic.primitives.Expression`, or + * a :class:`loopy.symbolic.Reduction`, or + * a substitution rule invocation. .. attribute:: allow_simultaneous @@ -451,9 +479,9 @@ class Reduction(p.Expression): in precisely one reduction, to avoid mis-nesting errors. 
""" - init_arg_names = ("operation", "inames", "expr", "allow_simultaneous") + init_arg_names = ("operation", "inames", "exprs", "allow_simultaneous") - def __init__(self, operation, inames, expr, allow_simultaneous=False): + def __init__(self, operation, inames, exprs, allow_simultaneous=False): if isinstance(inames, str): inames = tuple(iname.strip() for iname in inames.split(",")) @@ -475,30 +503,40 @@ class Reduction(p.Expression): from loopy.library.reduction import parse_reduction_op operation = parse_reduction_op(operation) + if not isinstance(exprs, tuple): + exprs = (exprs,) + from loopy.library.reduction import ReductionOperation assert isinstance(operation, ReductionOperation) self.operation = operation self.inames = inames - self.expr = expr + self.exprs = exprs self.allow_simultaneous = allow_simultaneous def __getinitargs__(self): - return (self.operation, self.inames, self.expr, self.allow_simultaneous) + return (self.operation, self.inames, self.exprs, self.allow_simultaneous) def get_hash(self): - return hash((self.__class__, self.operation, self.inames, - self.expr)) + return hash((self.__class__, self.operation, self.inames, self.exprs)) def is_equal(self, other): return (other.__class__ == self.__class__ and other.operation == self.operation and other.inames == self.inames - and other.expr == self.expr) + and other.exprs == self.exprs) def stringifier(self): return StringifyMapper + @property + def is_plain_tuple(self): + """ + :return: True if the reduction expression is a tuple, False if otherwise + (the inner expression will still have a tuple type) + """ + return isinstance(self.exprs, tuple) + @property @memoize_method def inames_set(self): @@ -924,7 +962,7 @@ class FunctionToPrimitiveMapper(IdentityMapper): turns those into the actual pymbolic primitives used for that. 
""" - def _parse_reduction(self, operation, inames, red_expr, + def _parse_reduction(self, operation, inames, red_exprs, allow_simultaneous=False): if isinstance(inames, p.Variable): inames = (inames,) @@ -941,7 +979,7 @@ class FunctionToPrimitiveMapper(IdentityMapper): processed_inames.append(iname.name) - return Reduction(operation, tuple(processed_inames), red_expr, + return Reduction(operation, tuple(processed_inames), red_exprs, allow_simultaneous=allow_simultaneous) def map_call(self, expr): @@ -966,15 +1004,13 @@ class FunctionToPrimitiveMapper(IdentityMapper): raise TypeError("cse takes two arguments") elif name in ["reduce", "simul_reduce"]: - if len(expr.parameters) == 3: - operation, inames, red_expr = expr.parameters - - if not isinstance(operation, p.Variable): - raise TypeError("operation argument to reduce() " - "must be a symbol") + if len(expr.parameters) >= 3: + operation, inames = expr.parameters[:2] + red_exprs = expr.parameters[2:] - operation = parse_reduction_op(operation.name) - return self._parse_reduction(operation, inames, self.rec(red_expr), + operation = parse_reduction_op(str(operation)) + return self._parse_reduction(operation, inames, + tuple(self.rec(red_expr) for red_expr in red_exprs), allow_simultaneous=(name == "simul_reduce")) else: raise TypeError("invalid 'reduce' calling sequence") @@ -991,12 +1027,17 @@ class FunctionToPrimitiveMapper(IdentityMapper): operation = parse_reduction_op(name) if operation: - if len(expr.parameters) != 2: + # arg_count counts arguments but not inames + if len(expr.parameters) != 1 + operation.arg_count: raise RuntimeError("invalid invocation of " - "reduction operation '%s'" % expr.function.name) + "reduction operation '%s': expected %d arguments, " + "got %d instead" % (expr.function.name, + 1 + operation.arg_count, + len(expr.parameters))) - inames, red_expr = expr.parameters - return self._parse_reduction(operation, inames, self.rec(red_expr)) + inames = expr.parameters[0] + red_exprs = 
tuple(self.rec(param) for param in expr.parameters[1:]) + return self._parse_reduction(operation, inames, red_exprs) else: return IdentityMapper.map_call(self, expr) @@ -1385,7 +1426,10 @@ class IndexVariableFinder(CombineMapper): return result def map_reduction(self, expr): - result = self.rec(expr.expr) + if expr.is_plain_tuple: + result = self.combine(self.rec(sub_expr) for sub_expr in expr.exprs) + else: + result = self.rec(expr.exprs) if not (expr.inames_set & result): raise RuntimeError("reduction '%s' does not depend on " diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 575311b11..a1948b615 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -683,7 +683,8 @@ def set_temporary_scope(kernel, temp_var_names, scope): # {{{ reduction_arg_to_subst_rule -def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=None): +def reduction_arg_to_subst_rule( + knl, inames, insn_match=None, subst_rule_name=None, arg_number=0): if isinstance(inames, str): inames = [s.strip() for s in inames.split(",")] @@ -695,10 +696,15 @@ def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=No def map_reduction(expr, rec, nresults=1): if frozenset(expr.inames) != inames_set: + if expr.is_plain_tuple: + rec_result = tuple(rec(sub_expr) for sub_expr in expr.exprs) + else: + rec_result = rec(expr.exprs) + return type(expr)( operation=expr.operation, inames=expr.inames, - expr=rec(expr.expr), + exprs=rec_result, allow_simultaneous=expr.allow_simultaneous) if subst_rule_name is None: @@ -711,19 +717,27 @@ def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=No raise LoopyError("substitution rule '%s' already exists" % my_subst_rule_name) + if not expr.is_plain_tuple: + raise NotImplementedError("non-tuple reduction arguments not supported") + from loopy.kernel.data import SubstitutionRule substs[my_subst_rule_name] = SubstitutionRule( name=my_subst_rule_name, arguments=tuple(inames), - 
expression=expr.expr) + expression=expr.exprs[arg_number]) from pymbolic import var iname_vars = [var(iname) for iname in inames] + new_exprs = tuple(sub_expr + if i != arg_number + else var(my_subst_rule_name)(*iname_vars) + for i, sub_expr in enumerate(expr.exprs)) + return type(expr)( operation=expr.operation, inames=expr.inames, - expr=var(my_subst_rule_name)(*iname_vars), + exprs=new_exprs, allow_simultaneous=expr.allow_simultaneous) from loopy.symbolic import ReductionCallbackMapper diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index c35b50643..81db51a7e 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -145,7 +145,10 @@ class _InameSplitter(RuleAwareIdentityMapper): from loopy.symbolic import Reduction return Reduction(expr.operation, tuple(new_inames), - self.rec(expr.expr, expn_state), + (tuple(self.rec(sub_expr, expn_state) + for sub_expr in expr.exprs) + if expr.is_plain_tuple + else self.rec(expr.exprs, expn_state)), expr.allow_simultaneous) else: return super(_InameSplitter, self).map_reduction(expr, expn_state) @@ -1191,13 +1194,19 @@ class _ReductionSplitter(RuleAwareIdentityMapper): if self.direction == "in": return Reduction(expr.operation, tuple(leftover_inames), Reduction(expr.operation, tuple(self.inames), - self.rec(expr.expr, expn_state), + (tuple(self.rec(sub_expr, expn_state) + for sub_expr in expr.exprs) + if expr.is_plain_tuple + else self.rec(expr.exprs, expn_state)), expr.allow_simultaneous), expr.allow_simultaneous) elif self.direction == "out": return Reduction(expr.operation, tuple(self.inames), Reduction(expr.operation, tuple(leftover_inames), - self.rec(expr.expr, expn_state), + (tuple(self.rec(sub_expr, expn_state) + for sub_expr in expr.exprs) + if expr.is_plain_tuple + else self.rec(expr.exprs, expn_state)), expr.allow_simultaneous)) else: assert False @@ -1589,10 +1598,16 @@ class _ReductionInameUniquifier(RuleAwareIdentityMapper): from loopy.symbolic import Reduction return 
Reduction(expr.operation, tuple(new_inames), - self.rec( - SubstitutionMapper(make_subst_func(subst_dict))( - expr.expr), - expn_state), + (tuple(self.rec( + SubstitutionMapper(make_subst_func(subst_dict))( + sub_expr), + expn_state) + for sub_expr in expr.exprs) + if expr.is_plain_tuple + else self.rec( + SubstitutionMapper(make_subst_func(subst_dict))( + expr.exprs), + expn_state)), expr.allow_simultaneous) else: return super(_ReductionInameUniquifier, self).map_reduction( diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 4c1e423e9..cdba4a5cb 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -352,28 +352,30 @@ class TypeInferenceMapper(CombineMapper): return [self.kernel.index_dtype] def map_reduction(self, expr, return_tuple=False): - rec_result = self.rec(expr.expr) - - if rec_result: - rec_result, = rec_result - result = expr.operation.result_dtypes( - self.kernel, rec_result, expr.inames) - else: - result = expr.operation.result_dtypes( - self.kernel, None, expr.inames) + """ + :arg return_tuple: If *True*, treat the type of the reduction expression + as a tuple type. Otherwise, the number of expressions being reduced over + must equal 1, and the type of the first expression is returned. 
+ """ + rec_results = tuple(self.rec(sub_expr) for sub_expr in expr.exprs) - if result is None: + if any(len(rec_result) == 0 for rec_result in rec_results): return [] if return_tuple: - return [result] + from itertools import product + return list( + expr.operation.result_dtypes(self.kernel, *product_element) + for product_element in product(*rec_results)) else: - if len(result) != 1 and not return_tuple: + if len(rec_results) != 1: raise LoopyError("reductions with more or fewer than one " "return value may only be used in direct assignments") - return [result[0]] + return list( + expr.operation.result_dtypes(self.kernel, rec_result)[0] + for rec_result in rec_results[0]) # }}} @@ -617,4 +619,38 @@ def infer_unknown_types(kernel, expect_completion=False): # }}} + +# {{{ reduction expression helper + +def infer_arg_and_reduction_dtypes_for_reduction_expression( + kernel, expr, unknown_types_ok): + arg_dtypes = [] + + type_inf_mapper = TypeInferenceMapper(kernel) + import loopy as lp + + for sub_expr in expr.exprs: + try: + arg_dtype = type_inf_mapper(sub_expr) + except DependencyTypeInferenceFailure: + if unknown_types_ok: + arg_dtype = lp.auto + else: + raise LoopyError("failed to determine type of accumulator for " + "reduction sub-expression '%s'" % sub_expr) + else: + arg_dtype = arg_dtype.with_target(kernel.target) + + arg_dtypes.append(arg_dtype) + + reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes) + reduction_dtypes = tuple( + dt.with_target(kernel.target) + if dt is not lp.auto else dt + for dt in reduction_dtypes) + + return tuple(arg_dtypes), reduction_dtypes + +# }}} + # vim: foldmethod=marker diff --git a/test/test_reduction.py b/test/test_reduction.py index 5887df7a6..1dd11b492 100644 --- a/test/test_reduction.py +++ b/test/test_reduction.py @@ -297,7 +297,7 @@ def test_argmax(ctx_factory): knl = lp.make_kernel( "{[i]: 0<=i<%d}" % n, """ - max_val, max_idx = argmax(i, fabs(a[i])) + max_val, max_idx = argmax(i, fabs(a[i]), i) """) 
knl = lp.add_and_infer_dtypes(knl, {"a": np.float32}) @@ -393,16 +393,24 @@ def test_double_sum_made_unique(ctx_factory): assert b.get() == ref -def test_parallel_multi_output_reduction(): +def test_parallel_multi_output_reduction(ctx_factory): knl = lp.make_kernel( "{[i]: 0<=i<128}", """ - max_val, max_indices = argmax(i, fabs(a[i])) + max_val, max_indices = argmax(i, fabs(a[i]), i) """) knl = lp.tag_inames(knl, dict(i="l.0")) + knl = lp.add_dtypes(knl, dict(a=np.float64)) knl = lp.realize_reduction(knl) - print(knl) - # TODO: Add functional test + + ctx = ctx_factory() + + with cl.CommandQueue(ctx) as queue: + a = np.random.rand(128) + out, (max_index, max_val) = knl(queue, a=a) + + assert max_val == np.max(a) + assert max_index == np.argmax(np.abs(a)) if __name__ == "__main__": -- GitLab