diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 752e3e4da7908132e1b9ba001451d3b86bd037f9..581f090547370ca1b8cc4752dc70e9408e6ab37c 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -658,11 +658,7 @@ class MultiAssignmentBase(InstructionBase): @memoize_method def reduction_inames(self): def map_reduction(expr, rec): - if expr.is_plain_tuple: - for sub_expr in expr.exprs: - rec(sub_expr) - else: - rec(expr.exprs) + rec(expr.exprs) for iname in expr.inames: result.add(iname) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 6d6494b5e1fa9c671c40f9c8737f9292527c9360..5ece0db1dffd2cde118bc3104b90ce6faa14a448 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -97,11 +97,7 @@ def check_reduction_iname_uniqueness(kernel): iname_to_nonsimultaneous_reduction_count = {} def map_reduction(expr, rec): - if expr.is_plain_tuple: - for sub_expr in expr.exprs: - rec(sub_expr) - else: - rec(expr.exprs) + rec(expr.exprs) for iname in expr.inames: iname_to_reduction_count[iname] = ( @@ -493,6 +489,39 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): else: return val + def expand_inner_reduction(id, expr, nresults, depends_on, within_inames, + within_inames_is_final): + from pymbolic.primitives import Call + from loopy.symbolic import Reduction + assert isinstance(expr, (Call, Reduction)) + + temp_var_names = [ + var_name_gen(id + "_arg" + str(i)) + for i in range(nresults)] + + for name in temp_var_names: + from loopy.kernel.data import TemporaryVariable, temp_var_scope + new_temporary_variables[name] = TemporaryVariable( + name=name, + shape=(), + dtype=lp.auto, + scope=temp_var_scope.PRIVATE) + + from pymbolic import var + temp_vars = tuple(var(n) for n in temp_var_names) + + call_insn = make_assignment( + id=id, + assignees=temp_vars, + expression=expr, + depends_on=depends_on, + within_inames=within_inames, + within_inames_is_final=within_inames_is_final) + + generated_insns.append(call_insn) + + return temp_vars + # }}} # {{{ sequential @@ -536,14 +565,32 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): if insn.within_inames_is_final: update_insn_iname_deps = insn.within_inames | set(expr.inames) + reduction_insn_depends_on = set([init_id]) + + if not isinstance(expr.exprs, tuple): + get_args_insn_id = insn_id_gen( + "%s_%s_get" % (insn.id, "_".join(expr.inames))) + + reduction_expr = expand_inner_reduction( + id=get_args_insn_id, + expr=expr.exprs, + nresults=nresults, + depends_on=insn.depends_on, + within_inames=update_insn_iname_deps, + within_inames_is_final=insn.within_inames_is_final) + + reduction_insn_depends_on.add(get_args_insn_id) + else: + reduction_expr = expr.exprs + reduction_insn = make_assignment( id=update_id, assignees=acc_vars, expression=expr.operation( arg_dtypes, _strip_if_scalar(acc_vars, acc_vars), - _strip_if_scalar(acc_vars, expr.exprs)), - depends_on=frozenset([init_insn.id]) | insn.depends_on, + _strip_if_scalar(acc_vars, reduction_expr)), + depends_on=frozenset(reduction_insn_depends_on) | insn.depends_on, within_inames=update_insn_iname_deps, within_inames_is_final=insn.within_inames_is_final) @@ -670,6 +717,26 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): depends_on=frozenset()) generated_insns.append(init_neutral_insn) + transfer_depends_on = set([init_neutral_id, init_id]) + + if not isinstance(expr.exprs, tuple): + get_args_insn_id = insn_id_gen( + "%s_%s_get" % (insn.id, red_iname)) + + reduction_expr = expand_inner_reduction( + id=get_args_insn_id, + expr=expr.exprs, + nresults=nresults, + depends_on=insn.depends_on, + within_inames=( + (outer_insn_inames - frozenset(expr.inames)) + | frozenset([red_iname])), + within_inames_is_final=insn.within_inames_is_final) + + transfer_depends_on.add(get_args_insn_id) + else: + reduction_expr = expr.exprs + transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname)) transfer_insn = make_assignment( id=transfer_id, @@ -679,15 +746,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): expression=expr.operation( arg_dtypes, _strip_if_scalar( - expr.exprs, + neutral_var_names, tuple(var(nvn) for nvn in neutral_var_names)), - _strip_if_scalar(expr.exprs, expr.exprs)), + _strip_if_scalar(neutral_var_names, reduction_expr)), within_inames=( (outer_insn_inames - frozenset(expr.inames)) | frozenset([red_iname])), within_inames_is_final=insn.within_inames_is_final, - depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on, - no_sync_with=frozenset([(init_id, "any")])) + depends_on=frozenset(transfer_depends_on) | insn.depends_on, + no_sync_with=frozenset( + [(insn_id, "any") for insn_id in transfer_depends_on])) generated_insns.append(transfer_insn) cur_size = 1 @@ -699,7 +767,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): istage = 0 while cur_size > 1: - new_size = cur_size // 2 assert new_size * 2 == cur_size @@ -926,6 +993,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): kernel = lp.tag_inames(kernel, new_iname_tags) + print(kernel) + kernel = ( _hackily_ensure_multi_assignment_return_values_are_scoped_private( kernel)) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 8876e295027bdb9e8ee4f0f580f4742d249511ff..89ac05f70a07fd95c88a341801a2d079d9506611 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -96,9 +96,7 @@ class IdentityMapperMixin(object): return Reduction( expr.operation, tuple(new_inames), - (tuple(self.rec(e, *args) for e in expr.exprs) - if expr.is_plain_tuple - else self.rec(expr.exprs, *args)), + self.rec(expr.exprs, *args), allow_simultaneous=expr.allow_simultaneous) def map_tagged_variable(self, expr, *args): @@ -147,11 +145,7 @@ class WalkMapper(WalkMapperBase): if not self.visit(expr): return - if expr.is_plain_tuple: - for sub_expr in expr.exprs: - self.rec(sub_expr, *args) - else: - self.rec(expr.exprs, *args) + self.rec(expr.exprs, *args) map_tagged_variable = WalkMapperBase.map_variable @@ -169,10 +163,7 @@ class CallbackMapper(CallbackMapperBase, IdentityMapper): class CombineMapper(CombineMapperBase): def map_reduction(self, expr): - if expr.is_plain_tuple: - return self.combine(self.rec(sub_expr) for sub_expr in expr.exprs) - else: - return self.rec(expr.exprs) + return self.rec(expr.exprs) map_linear_subscript = CombineMapperBase.map_subscript @@ -203,12 +194,16 @@ class StringifyMapper(StringifyMapperBase): def map_reduction(self, expr, prec): from pymbolic.mapper.stringifier import PREC_NONE + + if isinstance(expr.exprs, tuple): + inner_expr = ", ".join(self.rec(e, PREC_NONE) for e in expr.exprs) + else: + inner_expr = self.rec(expr.exprs, PREC_NONE) + return "%sreduce(%s, [%s], %s)" % ( "simul_" if expr.allow_simultaneous else "", expr.operation, ", ".join(expr.inames), - (", ".join(self.rec(e, PREC_NONE) for e in expr.exprs) - if expr.is_plain_tuple - else self.rec(expr.exprs, PREC_NONE))) + inner_expr) def map_tagged_variable(self, expr, prec): return "%s$%s" % (expr.name, expr.tag) @@ -238,15 +233,6 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase): or type(expr.operation) != type(other.operation) # noqa ): return [] - if expr.is_plain_tuple != other.is_plain_tuple: - return [] - - if expr.is_plain_tuple: - for sub_expr_l, sub_expr_r in zip(expr.exprs, other.exprs): - unis = self.rec(sub_expr_l, sub_expr_r, unis) - if not unis: - break - return unis return self.rec(expr.exprs, other.exprs, unis) @@ -281,10 +267,7 @@ class DependencyMapper(DependencyMapperBase): self.rec(child, *args) for child in expr.parameters) def map_reduction(self, expr): - if expr.is_plain_tuple: - deps = self.combine(self.rec(sub_expr) for sub_expr in expr.exprs) - else: - deps = self.rec(expr.exprs) + deps = self.rec(expr.exprs) return deps - set(p.Variable(iname) for iname in expr.inames) def map_tagged_variable(self, expr): @@ -503,8 +486,13 @@ class Reduction(p.Expression): from loopy.library.reduction import parse_reduction_op operation = parse_reduction_op(operation) - if not isinstance(exprs, tuple): - exprs = (exprs,) + from pymbolic.primitives import Call + if not isinstance(exprs, (tuple, Reduction, Call)): + from loopy.diagnostic import LoopyError + print(exprs) + raise LoopyError( + "reduction argument must be a tuple, reduction, or substitution " + "invocation, got '%s'" % type(exprs).__name__) from loopy.library.reduction import ReductionOperation assert isinstance(operation, ReductionOperation) @@ -530,12 +518,11 @@ class Reduction(p.Expression): return StringifyMapper @property - def is_plain_tuple(self): - """ - :return: True if the reduction expression is a tuple, False if otherwise - (the inner expression will still have a tuple type) - """ - return isinstance(self.exprs, tuple) + def exprs_stripped_if_scalar(self): + if isinstance(self.exprs, tuple) and len(self.exprs) == 1: + return self.exprs[0] + else: + return self.exprs @property @memoize_method @@ -1426,10 +1413,7 @@ class IndexVariableFinder(CombineMapper): return result def map_reduction(self, expr): - if expr.is_plain_tuple: - result = self.combine(self.rec(sub_expr) for sub_expr in expr.exprs) - else: - result = self.rec(expr.exprs) + result = self.rec(expr.exprs) if not (expr.inames_set & result): raise RuntimeError("reduction '%s' does not depend on " diff --git a/loopy/transform/data.py b/loopy/transform/data.py index a1948b615cc09bd7b4c50774f14c6fd61364150e..ee5ffb6bcf3cda1971261ff29d4d14eafadd00ff 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -684,7 +684,8 @@ def set_temporary_scope(kernel, temp_var_names, scope): # {{{ reduction_arg_to_subst_rule def reduction_arg_to_subst_rule( - knl, inames, insn_match=None, subst_rule_name=None, arg_number=0): + knl, inames, insn_match=None, subst_rule_name=None, + strip_if_scalar=False): if isinstance(inames, str): inames = [s.strip() for s in inames.split(",")] @@ -696,10 +697,7 @@ def reduction_arg_to_subst_rule( def map_reduction(expr, rec, nresults=1): if frozenset(expr.inames) != inames_set: - if expr.is_plain_tuple: - rec_result = tuple(rec(sub_expr) for sub_expr in expr.exprs) - else: - rec_result = rec(expr.exprs) + rec_result = rec(expr.exprs) return type(expr)( operation=expr.operation, @@ -717,27 +715,22 @@ def reduction_arg_to_subst_rule( raise LoopyError("substitution rule '%s' already exists" % my_subst_rule_name) - if not expr.is_plain_tuple: - raise NotImplemented("non-tuple reduction arguments not supported") - from loopy.kernel.data import SubstitutionRule substs[my_subst_rule_name] = SubstitutionRule( name=my_subst_rule_name, arguments=tuple(inames), - expression=expr.exprs[arg_number]) + expression=( + expr.exprs_stripped_if_scalar + if strip_if_scalar + else expr.exprs)) from pymbolic import var iname_vars = [var(iname) for iname in inames] - new_exprs = tuple(sub_expr - if i != arg_number - else var(my_subst_rule_name)(*iname_vars) - for i, sub_expr in enumerate(expr.exprs)) - return type(expr)( operation=expr.operation, inames=expr.inames, - exprs=new_exprs, + exprs=var(my_subst_rule_name)(*iname_vars), allow_simultaneous=expr.allow_simultaneous) from loopy.symbolic import ReductionCallbackMapper diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 81db51a7e6f3aa6bcef1e804325c7628e32ae095..b9a386b2b69ab1c3136f5f91075bc0129e320748 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -145,10 +145,7 @@ class _InameSplitter(RuleAwareIdentityMapper): from loopy.symbolic import Reduction return Reduction(expr.operation, tuple(new_inames), - (tuple(self.rec(sub_expr, expn_state) - for sub_expr in expr.exprs) - if expr.is_plain_tuple - else self.rec(expr.exprs, expn_state)), + self.rec(expr.exprs, expn_state), expr.allow_simultaneous) else: return super(_InameSplitter, self).map_reduction(expr, expn_state) @@ -1194,20 +1191,15 @@ class _ReductionSplitter(RuleAwareIdentityMapper): if self.direction == "in": return Reduction(expr.operation, tuple(leftover_inames), Reduction(expr.operation, tuple(self.inames), - (tuple(self.rec(sub_expr, expn_state) - for sub_expr in expr.exprs) - if expr.is_plain_tuple - else self.rec(expr.exprs, expn_state)), + self.rec(expr.exprs, expn_state), expr.allow_simultaneous), expr.allow_simultaneous) elif self.direction == "out": return Reduction(expr.operation, tuple(self.inames), Reduction(expr.operation, tuple(leftover_inames), - (tuple(self.rec(sub_expr, expn_state) - for sub_expr in expr.exprs) - if expr.is_plain_tuple - else self.rec(expr.exprs, expn_state)), - expr.allow_simultaneous)) + self.rec(expr.exprs, expn_state), + expr.allow_simultaneous), + expr.allow_simultaneous) else: assert False else: @@ -1598,16 +1590,9 @@ class _ReductionInameUniquifier(RuleAwareIdentityMapper): from loopy.symbolic import Reduction return Reduction(expr.operation, tuple(new_inames), - (tuple(self.rec( - SubstitutionMapper(make_subst_func(subst_dict))( - sub_expr), - expn_state) - for sub_expr in expr.exprs) - if expr.is_plain_tuple - else self.rec( - SubstitutionMapper(make_subst_func(subst_dict))( - expr.exprs), - expn_state)), + self.rec( + SubstitutionMapper(make_subst_func(subst_dict))(expr.exprs), + expn_state), expr.allow_simultaneous) else: return super(_ReductionInameUniquifier, self).map_reduction( diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index a19e06ecdf7c9966501ebb9600ea4e01614363f4..7e70f8c77547d39e6402d05fe56ca5dfd8a1fc64 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -59,9 +59,33 @@ def storage_axis_exprs(storage_axis_sources, args): return result +# {{{ identity mapper + +class PrecomputeIdentityMapper(RuleAwareIdentityMapper): + + def map_reduction(self, expr, expn_state): + from pymbolic.primitives import Call + new_exprs = self.rec(expr.exprs, expn_state) + + # If the substitution rule was replaced, precompute turned it into a + # scalar, but since reduction only takes tuple types we turn it into a + # tuple here. + if isinstance(expr.exprs, Call) and not isinstance(new_exprs, Call): + new_exprs = (new_exprs,) + + from loopy.symbolic import Reduction + return Reduction( + expr.operation, + expr.inames, + new_exprs, + expr.allow_simultaneous) + +# }}} + + # {{{ gather rule invocations -class RuleInvocationGatherer(RuleAwareIdentityMapper): +class RuleInvocationGatherer(PrecomputeIdentityMapper): def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within): super(RuleInvocationGatherer, self).__init__(rule_mapping_context) @@ -131,7 +155,7 @@ class RuleInvocationGatherer(RuleAwareIdentityMapper): # {{{ replace rule invocation -class RuleInvocationReplacer(RuleAwareIdentityMapper): +class RuleInvocationReplacer(PrecomputeIdentityMapper): def __init__(self, rule_mapping_context, subst_name, subst_tag, within, access_descriptors, array_base_map, storage_axis_names, storage_axis_sources, diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 3c77c988261b63334f3cb8f0f84e2ea69c87901b..b6aa5d1ad055b316d68ee51e947e648df499d582 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -357,10 +357,18 @@ class TypeInferenceMapper(CombineMapper): as a tuple type. Otherwise, the number of expressions being reduced over must equal 1, and the type of the first expression is returned. """ - if expr.is_plain_tuple: + from loopy.symbolic import Reduction + from pymbolic.primitives import Call + + if isinstance(expr.exprs, tuple): rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs] + elif isinstance(expr.exprs, Reduction): + rec_results = [self.rec(expr.exprs, return_tuple=True)] + elif isinstance(expr.exprs, Call): + rec_results = [self.map_call(expr.exprs, return_tuple=return_tuple)] else: - rec_results = [self.rec(expr.exprs, return_tuple=return_tuple)] + raise LoopyError("unknown reduction type: '%s'" + % type(expr.exprs).__name__) if any(len(rec_result) == 0 for rec_result in rec_results): return [] @@ -629,7 +637,12 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression( type_inf_mapper = TypeInferenceMapper(kernel) import loopy as lp - for sub_expr in expr.exprs: + if isinstance(expr.exprs, tuple): + exprs = expr.exprs + else: + exprs = (expr.exprs,) + + for sub_expr in exprs: try: arg_dtype = type_inf_mapper(sub_expr) except DependencyTypeInferenceFailure: diff --git a/test/test_loopy.py b/test/test_loopy.py index 851a7f0762fcec3ccbb55399e183f5fb51322ac1..d5d1a1f31ba5ad9ecaeedeb92b1188d5208e37c6 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2108,6 +2108,28 @@ def test_barrier_insertion_near_bottom_of_loop(): assert_barrier_between(knl, "ainit", "aupdate", ignore_barriers_in_levels=[1]) +def test_multi_argument_reduction_type_inference(): + from loopy.type_inference import TypeInferenceMapper + from loopy.library.reduction import SegmentedSumReductionOperation + from loopy.types import to_loopy_type + op = SegmentedSumReductionOperation() + + knl = lp.make_kernel("{[i]: 0<=i<10}", "") + + int32 = to_loopy_type(np.int32) + + expr = lp.symbolic.Reduction( + operation=op, + inames=("i",), + exprs=op.neutral_element(int32, int32), + allow_simultaneous=True) + + t_inf_mapper = TypeInferenceMapper(knl) + + print(t_inf_mapper(expr, return_tuple=True)) + 1/0 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])