diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index d24b61c12e43cd16431c9727d8fb057319475633..f9648bde7dc4d685ca9daf63ecf15b69496c8651 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -36,15 +36,19 @@ class ReductionOperation(object):
         equality-comparable.
     """
 
-    def result_dtypes(self, target, arg_dtype, inames):
+    def result_dtypes(self, target, *arg_dtypes):
         """
-        :arg arg_dtype: may be None if not known
+        :arg arg_dtypes: may be None if not known
         :returns: None if not known, otherwise the returned type
         """
 
         raise NotImplementedError
 
-    def neutral_element(self, dtype, inames):
+    @property
+    def arg_count(self):
+        raise NotImplementedError
+
+    def neutral_element(self, *dtypes):
         raise NotImplementedError
 
     def __hash__(self):
@@ -55,7 +59,7 @@ class ReductionOperation(object):
         # Force subclasses to override
         raise NotImplementedError
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         raise NotImplementedError
 
     def __ne__(self, other):
@@ -87,7 +91,11 @@ class ScalarReductionOperation(ReductionOperation):
         """
         self.forced_result_type = forced_result_type
 
-    def result_dtypes(self, kernel, arg_dtype, inames):
+    @property
+    def arg_count(self):
+        return 1
+
+    def result_dtypes(self, kernel, arg_dtype):
         if self.forced_result_type is not None:
             return (self.parse_result_type(
                     kernel.target, self.forced_result_type),)
@@ -114,18 +122,18 @@ class ScalarReductionOperation(ReductionOperation):
 
 
 class SumReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return 0
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return operand1 + operand2
 
 
 class ProductReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return 1
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return operand1 * operand2
 
 
@@ -166,32 +174,144 @@ def get_ge_neutral(dtype):
 
 
 class MaxReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return get_ge_neutral(dtype)
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return var("max")(operand1, operand2)
 
 
 class MinReductionOperation(ScalarReductionOperation):
-    def neutral_element(self, dtype, inames):
+    def neutral_element(self, dtype):
         return get_le_neutral(dtype)
 
-    def __call__(self, dtype, operand1, operand2, inames):
+    def __call__(self, dtype, operand1, operand2):
         return var("min")(operand1, operand2)
 
 
+# {{{ segmented reduction
+
+class _SegmentedScalarReductionOperation(ReductionOperation):
+    def __init__(self, **kwargs):
+        self.inner_reduction = self.base_reduction_class(**kwargs)
+
+    @property
+    def arg_count(self):
+        return 2
+
+    def prefix(self, scalar_dtype, segment_flag_dtype):
+        return "loopy_segmented_%s_%s_%s" % (self.which,
+                scalar_dtype.numpy_dtype.type.__name__,
+                segment_flag_dtype.numpy_dtype.type.__name__)
+
+    def neutral_element(self, scalar_dtype, segment_flag_dtype):
+        return SegmentedFunction(self, (scalar_dtype, segment_flag_dtype),
+                "init")()
+
+    def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype):
+        return (self.inner_reduction.result_dtypes(kernel, scalar_dtype)
+                + (segment_flag_dtype,))
+
+    def __str__(self):
+        return "segmented(%s)" % self.which
+
+    def __hash__(self):
+        return hash(type(self))
+
+    def __eq__(self, other):
+        return type(self) == type(other)
+
+    def __call__(self, dtypes, operand1, operand2):
+        return SegmentedFunction(self, dtypes, "update")(*(operand1 + operand2))
+
+
+class SegmentedSumReductionOperation(_SegmentedScalarReductionOperation):
+    base_reduction_class = SumReductionOperation
+    which = "sum"
+    op = "((%s) + (%s))"
+
+
+class SegmentedProductReductionOperation(_SegmentedScalarReductionOperation):
+    base_reduction_class = ProductReductionOperation
+    op = "((%s) * (%s))"
+    which = "product"
+
+
+class SegmentedFunction(FunctionIdentifier):
+    init_arg_names = ("reduction_op", "dtypes", "name")
+
+    def __init__(self, reduction_op, dtypes, name):
+        """
+        :arg dtypes: A :class:`tuple` of `(scalar_dtype, segment_flag_dtype)`
+        """
+        self.reduction_op = reduction_op
+        self.dtypes = dtypes
+        self.name = name
+
+    @property
+    def scalar_dtype(self):
+        return self.dtypes[0]
+
+    @property
+    def segment_flag_dtype(self):
+        return self.dtypes[1]
+
+    def __getinitargs__(self):
+        return (self.reduction_op, self.dtypes, self.name)
+
+
+def get_segmented_function_preamble(kernel, func_id):
+    op = func_id.reduction_op
+    prefix = op.prefix(func_id.scalar_dtype, func_id.segment_flag_dtype)
+
+    from pymbolic.mapper.c_code import CCodeMapper
+
+    c_code_mapper = CCodeMapper()
+
+    return (prefix, """
+    inline %(scalar_t)s %(prefix)s_init(%(segment_flag_t)s *segment_flag_out)
+    {
+        *segment_flag_out = 0;
+        return %(neutral)s;
+    }
+
+    inline %(scalar_t)s %(prefix)s_update(
+        %(scalar_t)s op1, %(segment_flag_t)s segment_flag1,
+        %(scalar_t)s op2, %(segment_flag_t)s segment_flag2,
+        %(segment_flag_t)s *segment_flag_out)
+    {
+        *segment_flag_out = segment_flag1 | segment_flag2;
+        return segment_flag2 ? op2 : %(combined)s;
+    }
+    """ % dict(
+            scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype),
+            prefix=prefix,
+            segment_flag_t=kernel.target.dtype_to_typename(
+                func_id.segment_flag_dtype),
+            neutral=c_code_mapper(
+                op.inner_reduction.neutral_element(func_id.scalar_dtype)),
+            combined=op.op % ("op1", "op2"),
+            ))
+
+
+# }}}
+
+
 # {{{ argmin/argmax
 
 class _ArgExtremumReductionOperation(ReductionOperation):
-    def prefix(self, dtype):
-        return "loopy_arg%s_%s" % (self.which, dtype.numpy_dtype.type.__name__)
+    def prefix(self, scalar_dtype, index_dtype):
+        return "loopy_arg%s_%s_%s" % (self.which,
+                index_dtype.numpy_dtype.type.__name__,
+                scalar_dtype.numpy_dtype.type.__name__)
+
+    def result_dtypes(self, kernel, scalar_dtype, index_dtype):
+        return (scalar_dtype, index_dtype)
 
-    def result_dtypes(self, kernel, dtype, inames):
-        return (dtype, kernel.index_dtype)
+    def neutral_element(self, scalar_dtype, index_dtype):
+        return ArgExtFunction(self, (scalar_dtype, index_dtype), "init")()
 
-    def neutral_element(self, dtype, inames):
-        return ArgExtFunction(self, dtype, "init", inames)()
+    def __str__(self):
+        return self.which
 
     def __hash__(self):
         return hash(type(self))
@@ -199,11 +319,12 @@ class _ArgExtremumReductionOperation(ReductionOperation):
     def __eq__(self, other):
         return type(self) == type(other)
 
-    def __call__(self, dtype, operand1, operand2, inames):
-        iname, = inames
+    @property
+    def arg_count(self):
+        return 2
 
-        return ArgExtFunction(self, dtype, "update", inames)(
-                *(operand1 + (operand2, var(iname))))
+    def __call__(self, dtypes, operand1, operand2):
+        return ArgExtFunction(self, dtypes, "update")(*(operand1 + operand2))
 
 
 class ArgMaxReductionOperation(_ArgExtremumReductionOperation):
@@ -219,21 +340,28 @@ class ArgMinReductionOperation(_ArgExtremumReductionOperation):
 
 
 class ArgExtFunction(FunctionIdentifier):
-    init_arg_names = ("reduction_op", "scalar_dtype", "name", "inames")
+    init_arg_names = ("reduction_op", "dtypes", "name")
 
-    def __init__(self, reduction_op, scalar_dtype, name, inames):
+    def __init__(self, reduction_op, dtypes, name):
         self.reduction_op = reduction_op
-        self.scalar_dtype = scalar_dtype
+        self.dtypes = dtypes
         self.name = name
-        self.inames = inames
+
+    @property
+    def scalar_dtype(self):
+        return self.dtypes[0]
+
+    @property
+    def index_dtype(self):
+        return self.dtypes[1]
 
     def __getinitargs__(self):
-        return (self.reduction_op, self.scalar_dtype, self.name, self.inames)
+        return (self.reduction_op, self.dtypes, self.name)
 
 
 def get_argext_preamble(kernel, func_id):
     op = func_id.reduction_op
-    prefix = op.prefix(func_id.scalar_dtype)
+    prefix = op.prefix(func_id.scalar_dtype, func_id.index_dtype)
 
     from pymbolic.mapper.c_code import CCodeMapper
 
@@ -267,7 +395,7 @@ def get_argext_preamble(kernel, func_id):
     """ % dict(
             scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype),
             prefix=prefix,
-            index_t=kernel.target.dtype_to_typename(kernel.index_dtype),
+            index_t=kernel.target.dtype_to_typename(func_id.index_dtype),
             neutral=c_code_mapper(neutral(func_id.scalar_dtype)),
             comp=op.update_comparison,
             ))
@@ -284,6 +412,8 @@ _REDUCTION_OPS = {
         "min": MinReductionOperation,
         "argmax": ArgMaxReductionOperation,
         "argmin": ArgMinReductionOperation,
+        "segmented(sum)": SegmentedSumReductionOperation,
+        "segmented(product)": SegmentedProductReductionOperation,
         }
 
 _REDUCTION_OP_PARSERS = [
@@ -325,32 +455,34 @@ def parse_reduction_op(name):
 
 def reduction_function_mangler(kernel, func_id, arg_dtypes):
     if isinstance(func_id, ArgExtFunction) and func_id.name == "init":
-        from loopy.target.opencl import OpenCLTarget
-        if not isinstance(kernel.target, OpenCLTarget):
-            raise LoopyError("only OpenCL supported for now")
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
 
         op = func_id.reduction_op
 
         from loopy.kernel.data import CallMangleInfo
         return CallMangleInfo(
-                target_name="%s_init" % op.prefix(func_id.scalar_dtype),
+                target_name="%s_init" % op.prefix(
+                    func_id.scalar_dtype, func_id.index_dtype),
                 result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.inames),
+                    kernel, func_id.scalar_dtype, func_id.index_dtype),
                 arg_dtypes=(),
                 )
 
     elif isinstance(func_id, ArgExtFunction) and func_id.name == "update":
-        from loopy.target.opencl import OpenCLTarget
-        if not isinstance(kernel.target, OpenCLTarget):
-            raise LoopyError("only OpenCL supported for now")
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
 
         op = func_id.reduction_op
 
         from loopy.kernel.data import CallMangleInfo
         return CallMangleInfo(
-                target_name="%s_update" % op.prefix(func_id.scalar_dtype),
+                target_name="%s_update" % op.prefix(
+                    func_id.scalar_dtype, func_id.index_dtype),
                 result_dtypes=op.result_dtypes(
-                    kernel, func_id.scalar_dtype, func_id.inames),
+                    kernel, func_id.scalar_dtype, func_id.index_dtype),
                 arg_dtypes=(
                     func_id.scalar_dtype,
                     kernel.index_dtype,
@@ -358,6 +490,42 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes):
                     kernel.index_dtype),
                 )
 
+    elif isinstance(func_id, SegmentedFunction) and func_id.name == "init":
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
+
+        op = func_id.reduction_op
+
+        from loopy.kernel.data import CallMangleInfo
+        return CallMangleInfo(
+                target_name="%s_init" % op.prefix(
+                    func_id.scalar_dtype, func_id.segment_flag_dtype),
+                result_dtypes=op.result_dtypes(
+                    kernel, func_id.scalar_dtype, func_id.segment_flag_dtype),
+                arg_dtypes=(),
+                )
+
+    elif isinstance(func_id, SegmentedFunction) and func_id.name == "update":
+        from loopy.target.opencl import CTarget
+        if not isinstance(kernel.target, CTarget):
+            raise LoopyError("%s: only C-like targets supported for now" % func_id)
+
+        op = func_id.reduction_op
+
+        from loopy.kernel.data import CallMangleInfo
+        return CallMangleInfo(
+                target_name="%s_update" % op.prefix(
+                    func_id.scalar_dtype, func_id.segment_flag_dtype),
+                result_dtypes=op.result_dtypes(
+                    kernel, func_id.scalar_dtype, func_id.segment_flag_dtype),
+                arg_dtypes=(
+                    func_id.scalar_dtype,
+                    func_id.segment_flag_dtype,
+                    func_id.scalar_dtype,
+                    func_id.segment_flag_dtype),
+                )
+
     return None
 
@@ -371,4 +539,10 @@ def reduction_preamble_generator(preamble_info):
 
             yield get_argext_preamble(preamble_info.kernel, func.name)
 
+        elif isinstance(func.name, SegmentedFunction):
+            if not isinstance(preamble_info.kernel.target, OpenCLTarget):
+                raise LoopyError("only OpenCL supported for now")
+
+            yield get_segmented_function_preamble(preamble_info.kernel, func.name)
+
 # vim: fdm=marker
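Reading aid, not part of the patch: the C preamble generated above follows the
standard segmented-combine rule — a set segment flag on the right-hand operand
starts a new segment and discards the accumulator so far. A minimal Python
model of the "sum" case, with made-up input data:

    def segmented_sum_update(op1, flag1, op2, flag2):
        # mirrors %(prefix)s_update: flags are OR'd together; a set flag2
        # resets the running value to op2 instead of combining
        return (op2 if flag2 else op1 + op2), flag1 | flag2

    acc, acc_flag = 0, 0  # neutral element, as produced by %(prefix)s_init
    for val, flag in [(1, 0), (2, 0), (3, 1), (4, 0)]:
        acc, acc_flag = segmented_sum_update(acc, acc_flag, val, flag)

    assert acc == 7  # the flag on 3 started a new segment: 3 + 4
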
+ """ + + insn_id_gen = kernel.get_instruction_id_generator() + var_name_gen = kernel.get_var_name_generator() + + new_or_updated_instructions = {} + new_temporaries = {} + + dep_map = dict( + (insn.id, insn.depends_on) for insn in kernel.instructions) + + inverse_dep_map = dict((insn.id, set()) for insn in kernel.instructions) + + import six + for insn_id, deps in six.iteritems(dep_map): + for dep in deps: + inverse_dep_map[dep].add(insn_id) + + del dep_map + + # {{{ utils + + def _add_to_no_sync_with(insn_id, new_no_sync_with_params): + insn = kernel.id_to_insn.get(insn_id) + insn = new_or_updated_instructions.get(insn_id, insn) + new_or_updated_instructions[insn_id] = ( + insn.copy( + no_sync_with=( + insn.no_sync_with | frozenset(new_no_sync_with_params)))) + + def _add_to_depends_on(insn_id, new_depends_on_params): + insn = kernel.id_to_insn.get(insn_id) + insn = new_or_updated_instructions.get(insn_id, insn) + new_or_updated_instructions[insn_id] = ( + insn.copy( + depends_on=insn.depends_on | frozenset(new_depends_on_params))) + + # }}} + + from loopy.kernel.instruction import CallInstruction + for insn in kernel.instructions: + if not isinstance(insn, CallInstruction): + continue + + if len(insn.assignees) <= 1: + continue + + assignees = insn.assignees + assignee_var_names = insn.assignee_var_names() + + new_assignees = [assignees[0]] + newly_added_assignments_ids = set() + needs_replacement = False + + last_added_insn_id = insn.id + + from loopy.kernel.data import temp_var_scope, TemporaryVariable + + FIRST_POINTER_ASSIGNEE_IDX = 1 # noqa + + for assignee_nr, assignee_var_name, assignee in zip( + range(FIRST_POINTER_ASSIGNEE_IDX, len(assignees)), + assignee_var_names[FIRST_POINTER_ASSIGNEE_IDX:], + assignees[FIRST_POINTER_ASSIGNEE_IDX:]): + + if ( + assignee_var_name in kernel.temporary_variables + and + (kernel.temporary_variables[assignee_var_name].scope + == temp_var_scope.PRIVATE)): + new_assignees.append(assignee) + continue + + needs_replacement = True + + # {{{ generate a new assignent instruction + + new_assignee_name = var_name_gen( + "{insn_id}_retval_{assignee_nr}" + .format(insn_id=insn.id, assignee_nr=assignee_nr)) + + new_assignment_id = insn_id_gen( + "{insn_id}_assign_retval_{assignee_nr}" + .format(insn_id=insn.id, assignee_nr=assignee_nr)) + + newly_added_assignments_ids.add(new_assignment_id) + + import loopy as lp + new_temporaries[new_assignee_name] = ( + TemporaryVariable( + name=new_assignee_name, + dtype=lp.auto, + scope=temp_var_scope.PRIVATE)) + + from pymbolic import var + new_assignee = var(new_assignee_name) + new_assignees.append(new_assignee) + + new_or_updated_instructions[new_assignment_id] = ( + make_assignment( + assignees=(assignee,), + expression=new_assignee, + id=new_assignment_id, + depends_on=frozenset([last_added_insn_id]), + depends_on_is_final=True, + no_sync_with=( + insn.no_sync_with | frozenset([(insn.id, "any")])), + predicates=insn.predicates, + within_inames=insn.within_inames)) + + last_added_insn_id = new_assignment_id + + # }}} + + if not needs_replacement: + continue + + # {{{ update originating instruction + + orig_insn = new_or_updated_instructions.get(insn.id, insn) + + new_or_updated_instructions[insn.id] = ( + orig_insn.copy(assignees=tuple(new_assignees))) + + _add_to_no_sync_with(insn.id, + [(id, "any") for id in newly_added_assignments_ids]) + + # }}} + + # {{{ squash spurious memory dependencies amongst new assignments + + for new_insn_id in newly_added_assignments_ids: + _add_to_no_sync_with(new_insn_id, + [(id, 
"any") + for id in newly_added_assignments_ids + if id != new_insn_id]) + + # }}} + + # {{{ update instructions that depend on the originating instruction + + for inverse_dep in inverse_dep_map[insn.id]: + _add_to_depends_on(inverse_dep, newly_added_assignments_ids) + + for insn_id, scope in ( + new_or_updated_instructions[inverse_dep].no_sync_with): + if insn_id == insn.id: + _add_to_no_sync_with( + inverse_dep, + [(id, scope) for id in newly_added_assignments_ids]) + + # }}} + + new_temporary_variables = kernel.temporary_variables.copy() + new_temporary_variables.update(new_temporaries) + + new_instructions = ( + list(new_or_updated_instructions.values()) + + list(insn + for insn in kernel.instructions + if insn.id not in new_or_updated_instructions)) + + return kernel.copy(temporary_variables=new_temporary_variables, + instructions=new_instructions) + +# }}} + + def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): """Rewrites reductions into their imperative form. With *insn_id_filter* specified, operate only on the instruction with an instruction id matching @@ -295,12 +481,52 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): var_name_gen = kernel.get_var_name_generator() new_temporary_variables = kernel.temporary_variables.copy() - from loopy.type_inference import TypeInferenceMapper - type_inf_mapper = TypeInferenceMapper(kernel) + # {{{ helpers + + def _strip_if_scalar(reference, val): + if len(reference) == 1: + return val[0] + else: + return val + + def expand_inner_reduction(id, expr, nresults, depends_on, within_inames, + within_inames_is_final): + from pymbolic.primitives import Call + from loopy.symbolic import Reduction + assert isinstance(expr, (Call, Reduction)) + + temp_var_names = [ + var_name_gen(id + "_arg" + str(i)) + for i in range(nresults)] + + for name in temp_var_names: + from loopy.kernel.data import TemporaryVariable, temp_var_scope + new_temporary_variables[name] = TemporaryVariable( + name=name, + shape=(), + dtype=lp.auto, + scope=temp_var_scope.PRIVATE) + + from pymbolic import var + temp_vars = tuple(var(n) for n in temp_var_names) + + call_insn = make_assignment( + id=id, + assignees=temp_vars, + expression=expr, + depends_on=depends_on, + within_inames=within_inames, + within_inames_is_final=within_inames_is_final) + + generated_insns.append(call_insn) + + return temp_vars + + # }}} # {{{ sequential - def map_reduction_seq(expr, rec, nresults, arg_dtype, + def map_reduction_seq(expr, rec, nresults, arg_dtypes, reduction_dtypes): outer_insn_inames = temp_kernel.insn_inames(insn) @@ -328,7 +554,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): within_inames=outer_insn_inames - frozenset(expr.inames), within_inames_is_final=insn.within_inames_is_final, depends_on=frozenset(), - expression=expr.operation.neutral_element(arg_dtype, expr.inames)) + expression=expr.operation.neutral_element(*arg_dtypes)) generated_insns.append(init_insn) @@ -339,14 +565,32 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): if insn.within_inames_is_final: update_insn_iname_deps = insn.within_inames | set(expr.inames) + reduction_insn_depends_on = set([init_id]) + + if nresults > 1 and not isinstance(expr.expr, tuple): + get_args_insn_id = insn_id_gen( + "%s_%s_get" % (insn.id, "_".join(expr.inames))) + + reduction_expr = expand_inner_reduction( + id=get_args_insn_id, + expr=expr.expr, + nresults=nresults, + depends_on=insn.depends_on, + within_inames=update_insn_iname_deps, 
+                    within_inames_is_final=insn.within_inames_is_final)
+
+            reduction_insn_depends_on.add(get_args_insn_id)
+        else:
+            reduction_expr = expr.expr
+
         reduction_insn = make_assignment(
                 id=update_id,
                 assignees=acc_vars,
                 expression=expr.operation(
-                    arg_dtype,
-                    acc_vars if len(acc_vars) > 1 else acc_vars[0],
-                    expr.expr, expr.inames),
-                depends_on=frozenset([init_insn.id]) | insn.depends_on,
+                    arg_dtypes,
+                    _strip_if_scalar(acc_vars, acc_vars),
+                    reduction_expr),
+                depends_on=frozenset(reduction_insn_depends_on) | insn.depends_on,
                 within_inames=update_insn_iname_deps,
                 within_inames_is_final=insn.within_inames_is_final)
 
@@ -382,7 +626,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 v[iname].lt_set(v[0] + size)).get_basic_sets()
         return bs
 
-    def map_reduction_local(expr, rec, nresults, arg_dtype,
+    def map_reduction_local(expr, rec, nresults, arg_dtypes,
             reduction_dtypes):
         red_iname, = expr.inames
 
@@ -441,7 +685,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         base_iname_deps = outer_insn_inames - frozenset(expr.inames)
 
-        neutral = expr.operation.neutral_element(arg_dtype, expr.inames)
+        neutral = expr.operation.neutral_element(*arg_dtypes)
 
         init_id = insn_id_gen("%s_%s_init" % (insn.id, red_iname))
         init_insn = make_assignment(
@@ -455,12 +699,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 depends_on=frozenset())
         generated_insns.append(init_insn)
 
-        def _strip_if_scalar(c):
-            if len(acc_vars) == 1:
-                return c[0]
-            else:
-                return c
-
         init_neutral_id = insn_id_gen("%s_%s_init_neutral" % (insn.id, red_iname))
         init_neutral_insn = make_assignment(
                 id=init_neutral_id,
@@ -471,6 +709,26 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 depends_on=frozenset())
         generated_insns.append(init_neutral_insn)
 
+        transfer_depends_on = set([init_neutral_id, init_id])
+
+        if nresults > 1 and not isinstance(expr.expr, tuple):
+            get_args_insn_id = insn_id_gen(
+                    "%s_%s_get" % (insn.id, red_iname))
+
+            reduction_expr = expand_inner_reduction(
+                    id=get_args_insn_id,
+                    expr=expr.expr,
+                    nresults=nresults,
+                    depends_on=insn.depends_on,
+                    within_inames=(
+                        (outer_insn_inames - frozenset(expr.inames))
+                        | frozenset([red_iname])),
+                    within_inames_is_final=insn.within_inames_is_final)
+
+            transfer_depends_on.add(get_args_insn_id)
+        else:
+            reduction_expr = expr.expr
+
         transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname))
         transfer_insn = make_assignment(
                 id=transfer_id,
@@ -478,15 +736,18 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                     acc_var[outer_local_iname_vars + (var(red_iname),)]
                     for acc_var in acc_vars),
                 expression=expr.operation(
-                    arg_dtype,
-                    _strip_if_scalar(tuple(var(nvn) for nvn in neutral_var_names)),
-                    expr.expr, expr.inames),
+                    arg_dtypes,
+                    _strip_if_scalar(
+                        neutral_var_names,
+                        tuple(var(nvn) for nvn in neutral_var_names)),
+                    reduction_expr),
                 within_inames=(
                     (outer_insn_inames - frozenset(expr.inames))
                     | frozenset([red_iname])),
                 within_inames_is_final=insn.within_inames_is_final,
-                depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on,
-                no_sync_with=frozenset([(init_id, "any")]))
+                depends_on=frozenset(transfer_depends_on) | insn.depends_on,
+                no_sync_with=frozenset(
+                    [(insn_id, "any") for insn_id in transfer_depends_on]))
 
         generated_insns.append(transfer_insn)
 
         cur_size = 1
@@ -498,7 +759,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         istage = 0
         while cur_size > 1:
-
             new_size = cur_size // 2
             assert new_size * 2 == cur_size
 
@@ -513,17 +773,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                         acc_var[outer_local_iname_vars + (var(stage_exec_iname),)]
                         for acc_var in acc_vars),
                     expression=expr.operation(
-                        arg_dtype,
-                        _strip_if_scalar(tuple(
+                        arg_dtypes,
+                        _strip_if_scalar(acc_vars, tuple(
                             acc_var[
                                 outer_local_iname_vars + (var(stage_exec_iname),)]
                             for acc_var in acc_vars)),
-                        _strip_if_scalar(tuple(
+                        _strip_if_scalar(acc_vars, tuple(
                             acc_var[
                                 outer_local_iname_vars + (
                                     var(stage_exec_iname) + new_size,)]
-                            for acc_var in acc_vars)),
-                        expr.inames),
+                            for acc_var in acc_vars))),
                     within_inames=(
                         base_iname_deps | frozenset([stage_exec_iname])),
                     within_inames_is_final=insn.within_inames_is_final,
@@ -554,24 +813,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
         # Only expand one level of reduction at a time, going from outermost to
        # innermost. Otherwise we get the (iname + insn) dependencies wrong.
 
-        try:
-            arg_dtype = type_inf_mapper(expr.expr)
-        except DependencyTypeInferenceFailure:
-            if unknown_types_ok:
-                arg_dtype = lp.auto
-
-                reduction_dtypes = (lp.auto,)*nresults
-
-            else:
-                raise LoopyError("failed to determine type of accumulator for "
-                        "reduction '%s'" % expr)
-        else:
-            arg_dtype = arg_dtype.with_target(kernel.target)
-
-            reduction_dtypes = expr.operation.result_dtypes(
-                        kernel, arg_dtype, expr.inames)
-            reduction_dtypes = tuple(
-                    dt.with_target(kernel.target) for dt in reduction_dtypes)
+        from loopy.type_inference import (
+                infer_arg_and_reduction_dtypes_for_reduction_expression)
+        arg_dtypes, reduction_dtypes = (
+                infer_arg_and_reduction_dtypes_for_reduction_expression(
+                    temp_kernel, expr, unknown_types_ok))
 
         outer_insn_inames = temp_kernel.insn_inames(insn)
         bad_inames = frozenset(expr.inames) & outer_insn_inames
@@ -621,10 +867,10 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         if n_sequential:
             assert n_local_par == 0
-            return map_reduction_seq(expr, rec, nresults, arg_dtype,
+            return map_reduction_seq(expr, rec, nresults, arg_dtypes,
                     reduction_dtypes)
         elif n_local_par:
-            return map_reduction_local(expr, rec, nresults, arg_dtype,
+            return map_reduction_local(expr, rec, nresults, arg_dtypes,
                     reduction_dtypes)
         else:
             from loopy.diagnostic import warn_with_kernel
@@ -739,6 +985,10 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
     kernel = lp.tag_inames(kernel, new_iname_tags)
 
+    kernel = (
+            _hackily_ensure_multi_assignment_return_values_are_scoped_private(
+                kernel))
+
     return kernel
 
 # }}}
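To make the privatization pass above concrete, consider a hypothetical
multi-assignment call whose second assignee is not a private temporary
(sketched in loopy's instruction notation; the generated names follow the
"{insn_id}_retval_{assignee_nr}" and "{insn_id}_assign_retval_{assignee_nr}"
patterns used by the generators above). The pass turns

    a, b = segmented_sum(x, y, z, w)  {id=call}

into

    a, call_retval_1 = segmented_sum(x, y, z, w)  {id=call}
    b = call_retval_1  {id=call_assign_retval_1, dep=call}

where call_retval_1 is a fresh PRIVATE temporary, the copy-out assignment is
no_sync_with the originating call, and instructions that depended on "call"
are updated to also depend on the copy-outs.
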
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 50c891be476810887720c4e13c9659966b431f5d..f1a494f30d469511817d204c0476ff79abe00e3b 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -95,7 +95,8 @@ class IdentityMapperMixin(object):
             new_inames.append(new_sym_iname.name)
 
         return Reduction(
-                expr.operation, tuple(new_inames), self.rec(expr.expr, *args),
+                expr.operation, tuple(new_inames),
+                self.rec(expr.expr, *args),
                 allow_simultaneous=expr.allow_simultaneous)
 
     def map_tagged_variable(self, expr, *args):
@@ -192,9 +193,12 @@ class StringifyMapper(StringifyMapperBase):
         return "loc.%d" % expr.index
 
     def map_reduction(self, expr, prec):
+        from pymbolic.mapper.stringifier import PREC_NONE
+
         return "%sreduce(%s, [%s], %s)" % (
                 "simul_" if expr.allow_simultaneous else "",
-                expr.operation, ", ".join(expr.inames), expr.expr)
+                expr.operation, ", ".join(expr.inames),
+                self.rec(expr.expr, PREC_NONE))
 
     def map_tagged_variable(self, expr, prec):
         return "%s$%s" % (expr.name, expr.tag)
@@ -258,8 +262,8 @@ class DependencyMapper(DependencyMapperBase):
                 self.rec(child, *args) for child in expr.parameters)
 
     def map_reduction(self, expr):
-        return (self.rec(expr.expr)
-                - set(p.Variable(iname) for iname in expr.inames))
+        deps = self.rec(expr.expr)
+        return deps - set(p.Variable(iname) for iname in expr.inames)
 
     def map_tagged_variable(self, expr):
         return set([expr])
@@ -428,7 +432,7 @@ class TaggedVariable(p.Variable):
 
 
 class Reduction(p.Expression):
-    """Represents a reduction operation on :attr:`expr`
+    """Represents a reduction operation on :attr:`exprs`
     across :attr:`inames`.
 
     .. attribute:: operation
@@ -442,8 +446,11 @@ class Reduction(p.Expression):
 
     .. attribute:: expr
 
-        The expression (as a :class:`pymbolic.primitives.Expression`)
-        on which reduction is performed.
+        An expression which may have tuple type. If the expression has tuple
+        type, it must be one of the following:
+
+        * a :class:`tuple` of :class:`pymbolic.primitives.Expression`, or
+        * a :class:`loopy.symbolic.Reduction`, or
+        * a function call or substitution rule invocation.
 
     .. attribute:: allow_simultaneous
@@ -478,6 +485,22 @@ class Reduction(p.Expression):
         from loopy.library.reduction import ReductionOperation
         assert isinstance(operation, ReductionOperation)
 
+        from loopy.diagnostic import LoopyError
+
+        if operation.arg_count > 1:
+            from pymbolic.primitives import Call
+
+            if not isinstance(expr, (tuple, Reduction, Call)):
+                raise LoopyError("reduction argument must be one of "
+                                 "a tuple, reduction, or call; "
+                                 "got '%s'" % type(expr).__name__)
+        else:
+            # Sanity checks
+            if isinstance(expr, tuple):
+                raise LoopyError("got a tuple argument to a scalar reduction")
+            elif isinstance(expr, Reduction) and expr.is_tuple_typed:
+                raise LoopyError("got a tuple typed argument to a scalar reduction")
+
         self.operation = operation
         self.inames = inames
         self.expr = expr
@@ -487,8 +510,7 @@ class Reduction(p.Expression):
         return (self.operation, self.inames, self.expr, self.allow_simultaneous)
 
     def get_hash(self):
-        return hash((self.__class__, self.operation, self.inames,
-            self.expr))
+        return hash((self.__class__, self.operation, self.inames, self.expr))
 
     def is_equal(self, other):
         return (other.__class__ == self.__class__
@@ -499,6 +521,10 @@ class Reduction(p.Expression):
     def stringifier(self):
         return StringifyMapper
 
+    @property
+    def is_tuple_typed(self):
+        return self.operation.arg_count > 1
+
     @property
     @memoize_method
     def inames_set(self):
@@ -924,7 +950,7 @@ class FunctionToPrimitiveMapper(IdentityMapper):
     turns those into the actual pymbolic primitives used for that.
""" - def _parse_reduction(self, operation, inames, red_expr, + def _parse_reduction(self, operation, inames, red_exprs, allow_simultaneous=False): if isinstance(inames, p.Variable): inames = (inames,) @@ -941,7 +967,10 @@ class FunctionToPrimitiveMapper(IdentityMapper): processed_inames.append(iname.name) - return Reduction(operation, tuple(processed_inames), red_expr, + if len(red_exprs) == 1: + red_exprs = red_exprs[0] + + return Reduction(operation, tuple(processed_inames), red_exprs, allow_simultaneous=allow_simultaneous) def map_call(self, expr): @@ -966,15 +995,14 @@ class FunctionToPrimitiveMapper(IdentityMapper): raise TypeError("cse takes two arguments") elif name in ["reduce", "simul_reduce"]: - if len(expr.parameters) == 3: - operation, inames, red_expr = expr.parameters - if not isinstance(operation, p.Variable): - raise TypeError("operation argument to reduce() " - "must be a symbol") + if len(expr.parameters) >= 3: + operation, inames = expr.parameters[:2] + red_exprs = expr.parameters[2:] - operation = parse_reduction_op(operation.name) - return self._parse_reduction(operation, inames, self.rec(red_expr), + operation = parse_reduction_op(str(operation)) + return self._parse_reduction(operation, inames, + tuple(self.rec(red_expr) for red_expr in red_exprs), allow_simultaneous=(name == "simul_reduce")) else: raise TypeError("invalid 'reduce' calling sequence") @@ -991,12 +1019,17 @@ class FunctionToPrimitiveMapper(IdentityMapper): operation = parse_reduction_op(name) if operation: - if len(expr.parameters) != 2: + # arg_count counts arguments but not inames + if len(expr.parameters) != 1 + operation.arg_count: raise RuntimeError("invalid invocation of " - "reduction operation '%s'" % expr.function.name) - - inames, red_expr = expr.parameters - return self._parse_reduction(operation, inames, self.rec(red_expr)) + "reduction operation '%s': expected %d arguments, " + "got %d instead" % (expr.function.name, + 1 + operation.arg_count, + len(expr.parameters))) + + inames = expr.parameters[0] + red_exprs = tuple(self.rec(param) for param in expr.parameters[1:]) + return self._parse_reduction(operation, inames, red_exprs) else: return IdentityMapper.map_call(self, expr) diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 4c1e423e93e104fecd0b49a2b1ef2b4a261e38e7..b8b0cbcbf1236cdf712da998922ac238261a3e6e 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -352,28 +352,39 @@ class TypeInferenceMapper(CombineMapper): return [self.kernel.index_dtype] def map_reduction(self, expr, return_tuple=False): - rec_result = self.rec(expr.expr) - - if rec_result: - rec_result, = rec_result - result = expr.operation.result_dtypes( - self.kernel, rec_result, expr.inames) + """ + :arg return_tuple: If *True*, treat the reduction as having tuple type. + Otherwise, if *False*, the reduction must have scalar type. 
+ """ + from loopy.symbolic import Reduction + from pymbolic.primitives import Call + + if not return_tuple and expr.is_tuple_typed: + raise LoopyError("reductions with more or fewer than one " + "return value may only be used in direct " + "assignments") + + if isinstance(expr.expr, tuple): + rec_results = [self.rec(sub_expr) for sub_expr in expr.expr] + from itertools import product + rec_results = product(*rec_results) + elif isinstance(expr.expr, Reduction): + rec_results = self.rec(expr.expr, return_tuple=return_tuple) + elif isinstance(expr.expr, Call): + rec_results = self.map_call(expr.expr, return_tuple=return_tuple) else: - result = expr.operation.result_dtypes( - self.kernel, None, expr.inames) - - if result is None: - return [] + if return_tuple: + raise LoopyError("unknown reduction type for tuple reduction: '%s'" + % type(expr.expr).__name__) + else: + rec_results = self.rec(expr.expr) if return_tuple: - return [result] - + return [expr.operation.result_dtypes(self.kernel, *rec_result) + for rec_result in rec_results] else: - if len(result) != 1 and not return_tuple: - raise LoopyError("reductions with more or fewer than one " - "return value may only be used in direct assignments") - - return [result[0]] + return [expr.operation.result_dtypes(self.kernel, rec_result)[0] + for rec_result in rec_results] # }}} @@ -617,4 +628,44 @@ def infer_unknown_types(kernel, expect_completion=False): # }}} + +# {{{ reduction expression helper + +def infer_arg_and_reduction_dtypes_for_reduction_expression( + kernel, expr, unknown_types_ok): + type_inf_mapper = TypeInferenceMapper(kernel) + import loopy as lp + + if expr.is_tuple_typed: + arg_dtypes_result = type_inf_mapper( + expr, return_tuple=True, return_dtype_set=True) + + if len(arg_dtypes_result) == 1: + arg_dtypes = arg_dtypes_result[0] + else: + if unknown_types_ok: + arg_dtypes = [lp.auto] * expr.operation.arg_count + else: + raise LoopyError("failed to determine types of accumulators for " + "reduction '%s'" % expr) + else: + try: + arg_dtypes = [type_inf_mapper(expr)] + except DependencyTypeInferenceFailure: + if unknown_types_ok: + arg_dtypes = [lp.auto] + else: + raise LoopyError("failed to determine type of accumulator for " + "reduction '%s'" % expr) + + reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes) + reduction_dtypes = tuple( + dt.with_target(kernel.target) + if dt is not lp.auto else dt + for dt in reduction_dtypes) + + return tuple(arg_dtypes), reduction_dtypes + +# }}} + # vim: foldmethod=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index 1218847a7c42bd420a993d86a7534f066c2ab20e..4bb6a27267bd7b1880265bdd5b47ee676a480fb3 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1987,19 +1987,28 @@ def test_integer_reduction(ctx_factory): dtype=to_loopy_type(vtype), shape=lp.auto) - reductions = [('max', lambda x: x == np.max(var_int)), - ('min', lambda x: x == np.min(var_int)), - ('sum', lambda x: x == np.sum(var_int)), - ('product', lambda x: x == np.prod(var_int)), - ('argmax', lambda x: (x[0] == np.max(var_int) and - var_int[out[1]] == np.max(var_int))), - ('argmin', lambda x: (x[0] == np.min(var_int) and - var_int[out[1]] == np.min(var_int)))] - - for reduction, function in reductions: + from collections import namedtuple + ReductionTest = namedtuple('ReductionTest', 'kind, check, args') + + reductions = [ + ReductionTest('max', lambda x: x == np.max(var_int), args='var[k]'), + ReductionTest('min', lambda x: x == np.min(var_int), args='var[k]'), + ReductionTest('sum', lambda x: 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 1218847a7c42bd420a993d86a7534f066c2ab20e..4bb6a27267bd7b1880265bdd5b47ee676a480fb3 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1987,19 +1987,28 @@ def test_integer_reduction(ctx_factory):
                                dtype=to_loopy_type(vtype),
                                shape=lp.auto)
 
-    reductions = [('max', lambda x: x == np.max(var_int)),
-                  ('min', lambda x: x == np.min(var_int)),
-                  ('sum', lambda x: x == np.sum(var_int)),
-                  ('product', lambda x: x == np.prod(var_int)),
-                  ('argmax', lambda x: (x[0] == np.max(var_int) and
-                                        var_int[out[1]] == np.max(var_int))),
-                  ('argmin', lambda x: (x[0] == np.min(var_int) and
-                                        var_int[out[1]] == np.min(var_int)))]
-
-    for reduction, function in reductions:
+    from collections import namedtuple
+    ReductionTest = namedtuple('ReductionTest', 'kind, check, args')
+
+    reductions = [
+        ReductionTest('max', lambda x: x == np.max(var_int), args='var[k]'),
+        ReductionTest('min', lambda x: x == np.min(var_int), args='var[k]'),
+        ReductionTest('sum', lambda x: x == np.sum(var_int), args='var[k]'),
+        ReductionTest('product', lambda x: x == np.prod(var_int), args='var[k]'),
+        ReductionTest('argmax',
+            lambda x: (
+                x[0] == np.max(var_int) and var_int[out[1]] == np.max(var_int)),
+            args='var[k], k'),
+        ReductionTest('argmin',
+            lambda x: (
+                x[0] == np.min(var_int) and var_int[out[1]] == np.min(var_int)),
+            args='var[k], k')
+    ]
+
+    for reduction, function, args in reductions:
         kstr = ("out" if 'arg' not in reduction
                     else "out[0], out[1]")
-        kstr += ' = {0}(k, var[k])'.format(reduction)
+        kstr += ' = {0}(k, {1})'.format(reduction, args)
 
         knl = lp.make_kernel('{[k]: 0<=k