From 059fef724ffefbed3bbff7976e68a0016b8dd1b8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 4 Feb 2022 01:08:03 -0600 Subject: [PATCH] Make realize_reduction actually recursive (closes gh-533) --- loopy/transform/realize_reduction.py | 849 +++++++++++++++------------ test/test_scan.py | 5 - 2 files changed, 476 insertions(+), 378 deletions(-) diff --git a/loopy/transform/realize_reduction.py b/loopy/transform/realize_reduction.py index c02c05fdf..67aa627f8 100644 --- a/loopy/transform/realize_reduction.py +++ b/loopy/transform/realize_reduction.py @@ -24,8 +24,9 @@ THE SOFTWARE. """ -from dataclasses import dataclass -from typing import Tuple, Dict, Callable, List, Optional, Set, Sequence +from dataclasses import dataclass, replace +from typing import (Tuple, Dict, Callable, List, Optional, Set, Sequence, + FrozenSet) import logging logger = logging.getLogger(__name__) @@ -33,10 +34,11 @@ logger = logging.getLogger(__name__) from pytools import memoize_on_first_arg from pytools.tag import Tag import islpy as isl +from pymbolic.primitives import Expression + +from pyrsistent import PMap from loopy.kernel.data import make_assignment -from loopy.kernel.tools import ( - kernel_has_global_barriers, find_most_recent_global_barrier) from loopy.symbolic import ReductionCallbackMapper from loopy.translation_unit import TranslationUnit from loopy.kernel.function_interface import CallableKernel @@ -51,30 +53,34 @@ from loopy.transform.instruction import replace_instruction_ids_in_insn # {{{ reduction realization context +@dataclass +class _ChangeFlag: + changes_made: bool + + @dataclass(frozen=True) class _ReductionRealizationContext: # {{{ read-only + mapper: "RealizeReductionCallbackMapper" + force_scan: bool automagic_scans_ok: bool unknown_types_ok: bool - # FIXME: This feels like a broken-by-design concept + # FIXME: This feels like a broken-by-design concept. force_outer_iname_for_scan: Optional[str] # We use the original kernel for a number of lookups whose value # we do not change and which might be already cached on it. orig_kernel: LoopKernel - kernel: LoopKernel - # FIXME: This shouldn't be here. We might generate multiple instructions - # in a nested manner. Why should the 'top-level' instruction be special? - insn: InstructionBase + id_prefix: str # }}} - # {{{ internally mutable + # {{{ internally mutable, same across entire recursion insn_id_gen: Callable[[str], str] var_name_gen: Callable[[str], str] @@ -83,28 +89,84 @@ class _ReductionRealizationContext: additional_insns: List[InstructionBase] domains: List[isl.BasicSet] additional_iname_tags: Dict[str, Sequence[Tag]] + # list only to facilitate mutation + boxed_callables_table: List[PMap] # FIXME: This is a broken-by-design concept. Local-parallel scans emit a # reduction internally. This serves to avoid force_scan acting on that # reduction. inames_added_for_scan: Set[str] - # FIXME: Clarify how these relate to recursively generated instructions. - new_insn_add_depends_on: Set[str] - new_insn_add_no_sync_with: Set[Tuple[str, str]] - new_insn_add_within_inames: Set[str] + # }}} + + # {{{ surrounding instruction, read-only (different at each recursive level) + + # These are attributes from 'surrounding' instruction, for generated + # instructions to potentially inherit. + surrounding_within_inames: FrozenSet[str] + surrounding_depends_on: FrozenSet[str] + surrounding_no_sync_with: FrozenSet[Tuple[str, str]] + surrounding_predicates: FrozenSet[Expression] # }}} - # {{{ change tracking + # {{{ surrounding instruction, internally mutable + # (different at each recursive level) + + # These are requested additions to attributes of the surrounding instruction. + + # FIXME add_within_inames seems broken by design. + surrounding_insn_add_within_inames: Set[str] + + surrounding_insn_add_depends_on: Set[str] + surrounding_insn_add_no_sync_with: Set[Tuple[str, str]] + + # }}} + + # {{{ change tracking (same across entire recursion) + + _change_flag: _ChangeFlag - were_changes_made: bool + @property + def were_changes_made(self): + return self._change_flag.changes_made def changes_made(self): - object.__setattr__(self, "were_changes_made", True) + self._change_flag.changes_made = True # }}} + def new_subinstruction(self, *, within_inames, depends_on, + no_sync_with=None, predicates=None): + if no_sync_with is None: + no_sync_with = self.surrounding_no_sync_with + if predicates is None: + predicates = self.surrounding_predicates + + return replace(self, + surrounding_within_inames=within_inames, + surrounding_depends_on=depends_on, + surrounding_no_sync_with=no_sync_with, + surrounding_predicates=predicates, + + surrounding_insn_add_within_inames=set(), + surrounding_insn_add_depends_on=set(), + surrounding_insn_add_no_sync_with=set()) + + def get_insn_kwargs(self): + return dict( + within_inames=( + self.surrounding_within_inames + | frozenset(self.surrounding_insn_add_within_inames)), + within_inames_is_final=True, + depends_on=( + self.surrounding_depends_on + | frozenset(self.surrounding_insn_add_depends_on)), + no_sync_with=( + self.surrounding_no_sync_with + | frozenset(self.surrounding_insn_add_no_sync_with)), + predicates=self.surrounding_predicates) + # }}} @@ -117,7 +179,7 @@ class _InameClassification: nonlocal_parallel: Tuple[str, ...] -def _classify_reduction_inames(kernel, inames): +def _classify_reduction_inames(red_realize_ctx, inames): sequential = [] local_par = [] nonlocal_par = [] @@ -127,7 +189,10 @@ def _classify_reduction_inames(kernel, inames): ConcurrentTag, filter_iname_tags_by_type) for iname in inames: - iname_tags = kernel.iname_tags(iname) + try: + iname_tags = red_realize_ctx.additional_iname_tags[iname] + except KeyError: + iname_tags = red_realize_ctx.kernel.iname_tags(iname) if filter_iname_tags_by_type(iname_tags, (UnrollTag, UnrolledIlpTag)): # These are nominally parallel, but we can live with @@ -333,8 +398,13 @@ def _try_infer_scan_candidate_from_expr( "(sweep iname: '%s', scan iname: '%s'): %s" % (expr, sweep_iname, scan_iname, v)) - return _ScanCandidateParameters(sweep_iname, scan_iname, sweep_lower_bound, - sweep_upper_bound, scan_lower_bound, stride) + return _ScanCandidateParameters( + sweep_iname=sweep_iname, + scan_iname=scan_iname, + sweep_lower_bound=sweep_lower_bound, + sweep_upper_bound=sweep_upper_bound, + scan_lower_bound=scan_lower_bound, + stride=stride) def _try_infer_sweep_iname(domain, scan_iname, candidate_inames): @@ -499,15 +569,16 @@ def _get_domain_with_iname_as_param(domain, iname): dim_type.set, iname_idx, 1) -def _create_domain_for_sweep_tracking(orig_domain, - tracking_iname, sweep_iname, sweep_min_value, scan_min_value, stride): +def _create_domain_for_sweep_tracking(orig_domain, tracking_iname, scan_param): + sp = scan_param + dim_type = isl.dim_type subd = isl.BasicSet.universe(orig_domain.params().space) # Add tracking_iname and sweep iname. - subd = _add_params_to_domain(subd, (sweep_iname, tracking_iname)) + subd = _add_params_to_domain(subd, (sp.sweep_iname, tracking_iname)) # Here we realize the domain: # @@ -526,11 +597,11 @@ def _create_domain_for_sweep_tracking(orig_domain, # affs = isl.affs_from_space(subd.space) - subd &= (affs[tracking_iname] - scan_min_value).ge_set(affs[0]) - subd &= (affs[tracking_iname] - scan_min_value)\ - .le_set(stride * (affs[sweep_iname] - sweep_min_value)) - subd &= (affs[tracking_iname] - scan_min_value)\ - .gt_set(stride * (affs[sweep_iname] - sweep_min_value - 1)) + subd &= (affs[tracking_iname] - sp.scan_lower_bound).ge_set(affs[0]) + subd &= (affs[tracking_iname] - sp.scan_lower_bound)\ + .le_set(sp.stride * (affs[sp.sweep_iname] - sp.sweep_lower_bound)) + subd &= (affs[tracking_iname] - sp.scan_lower_bound)\ + .gt_set(sp.stride * (affs[sp.sweep_iname] - sp.sweep_lower_bound - 1)) # Move tracking_iname into a set dim (NOT sweep iname). subd = subd.move_dims( @@ -539,7 +610,7 @@ def _create_domain_for_sweep_tracking(orig_domain, # Simplify (maybe). orig_domain_with_sweep_param = ( - _get_domain_with_iname_as_param(orig_domain, sweep_iname)) + _get_domain_with_iname_as_param(orig_domain, sp.sweep_iname)) subd = subd.gist_params(orig_domain_with_sweep_param.params()) subd, = subd.get_basic_sets() @@ -738,39 +809,56 @@ def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel): # {{{ RealizeReductionCallbackMapper class RealizeReductionCallbackMapper(ReductionCallbackMapper): - def __init__(self, callback, callables_table): + def __init__(self, callback): super().__init__(callback) - self.callables_table = callables_table def map_reduction(self, expr, **kwargs): - result, self.callables_table = self.callback(expr, rec=self.rec, - **kwargs) - return result + return self.callback(expr, **kwargs) - def map_if(self, expr, *, - callables_table, red_realize_ctx, - guarding_predicates, nresults): + def map_if(self, expr, *, red_realize_ctx, nresults): + common_kwargs = dict(nresults=nresults) - common_kwargs = dict( - callables_table=callables_table, - red_realize_ctx=red_realize_ctx, - nresults=nresults) + # {{{ generate code for condition + rrc_cond = replace(red_realize_ctx, + surrounding_insn_add_depends_on=set(), + surrounding_insn_add_no_sync_with=set(), + surrounding_insn_add_within_inames=set()) import pymbolic.primitives as prim rec_cond = self.rec( expr.condition, - guarding_predicates=guarding_predicates, + red_realize_ctx=rrc_cond, **common_kwargs) + assert not rrc_cond.surrounding_insn_add_no_sync_with + assert not rrc_cond.surrounding_insn_add_within_inames + + cond_dep_on = rrc_cond.surrounding_insn_add_depends_on + red_realize_ctx.surrounding_insn_add_depends_on.update(cond_dep_on) + + # }}} + return prim.If(rec_cond, self.rec(expr.then, - guarding_predicates=( - guarding_predicates - | frozenset([rec_cond])), + red_realize_ctx=replace( + red_realize_ctx, + surrounding_depends_on=( + red_realize_ctx.surrounding_depends_on + | cond_dep_on), + surrounding_predicates=( + red_realize_ctx.surrounding_predicates + | frozenset([rec_cond]) + )), **common_kwargs), self.rec(expr.else_, - guarding_predicates=( - guarding_predicates - | frozenset([prim.LogicalNot(rec_cond)])), + red_realize_ctx=replace( + red_realize_ctx, + surrounding_depends_on=( + red_realize_ctx.surrounding_depends_on + | cond_dep_on), + surrounding_predicates=( + red_realize_ctx.surrounding_predicates + | frozenset([prim.LogicalNot(rec_cond)]) + )), **common_kwargs)) # }}} @@ -788,13 +876,10 @@ def _strip_if_scalar(reference, val): def _preprocess_scan_arguments( red_realize_ctx, expr, nresults, scan_iname, track_iname, - newly_generated_insn_id_set, - insn_id_gen): + newly_generated_insn_id_set): """Does iname substitution within scan arguments and returns a set of values suitable to be passed to the binary op. Returns a tuple.""" - insn = red_realize_ctx.insn - if nresults > 1: inner_expr = expr @@ -802,21 +887,21 @@ def _preprocess_scan_arguments( # the arguments in order to pass them to the binary op - so we expand # items that are not "plain" tuples here. if not isinstance(inner_expr, tuple): - get_args_insn_id = insn_id_gen( - "{}_{}_get".format(insn.id, "_".join(expr.inames))) + get_args_insn_id = red_realize_ctx.insn_id_gen( + f"{red_realize_ctx.id_prefix}_{'_'.join(expr.inames)}_get") inner_expr = expand_inner_reduction( red_realize_ctx=red_realize_ctx, id=get_args_insn_id, expr=inner_expr, nresults=nresults, - depends_on=insn.depends_on, - within_inames=insn.within_inames | expr.inames, - within_inames_is_final=insn.within_inames_is_final, - predicates=insn.predicates, + depends_on=red_realize_ctx.surrounding_depends_on, + within_inames=red_realize_ctx.surrounding_within_inames, + predicates=red_realize_ctx.surrounding_predicates, ) - newly_generated_insn_id_set.add(get_args_insn_id) + newly_generated_insn_id_set = ( + newly_generated_insn_id_set | frozenset({get_args_insn_id})) updated_inner_exprs = tuple( replace_var_within_expr( @@ -829,14 +914,13 @@ def _preprocess_scan_arguments( red_realize_ctx.kernel, red_realize_ctx.var_name_gen, expr, scan_iname, track_iname),) - return updated_inner_exprs + return updated_inner_exprs, newly_generated_insn_id_set # }}} def expand_inner_reduction( - red_realize_ctx, id, expr, nresults, depends_on, within_inames, - within_inames_is_final, predicates): + red_realize_ctx, id, expr, nresults, depends_on, within_inames, predicates): # FIXME: use _make_temporaries from pymbolic.primitives import Call from loopy.symbolic import Reduction @@ -862,7 +946,7 @@ def expand_inner_reduction( expression=expr, depends_on=depends_on, within_inames=within_inames, - within_inames_is_final=within_inames_is_final, + within_inames_is_final=True, predicates=predicates) red_realize_ctx.additional_insns.append(call_insn) @@ -872,13 +956,8 @@ def expand_inner_reduction( # {{{ reduction type: sequential -def map_reduction_seq( - red_realize_ctx, expr, rec, callables_table, nresults, arg_dtypes, - reduction_dtypes, guarding_predicates): +def map_reduction_seq(red_realize_ctx, expr, nresults, arg_dtypes, reduction_dtypes): orig_kernel = red_realize_ctx.orig_kernel - insn = red_realize_ctx.insn - - outer_insn_inames = red_realize_ctx.insn.within_inames acc_var_names = _make_temporaries( red_realize_ctx=red_realize_ctx, @@ -888,37 +967,24 @@ def map_reduction_seq( dtypes=reduction_dtypes, address_space=AddressSpace.PRIVATE) - init_insn_depends_on = frozenset() - - # check first that the original kernel had global barriers - # if not, we don't need to check. Since the function - # kernel_has_global_barriers is cached, we don't do - # extra work compared to not checking. - # FIXME: Explain why we care about global barriers here - if kernel_has_global_barriers(orig_kernel): - global_barrier = find_most_recent_global_barrier( - red_realize_ctx.kernel, - insn.id) - - if global_barrier is not None: - init_insn_depends_on |= frozenset([global_barrier]) - from pymbolic import var acc_vars = tuple(var(n) for n in acc_var_names) init_id = red_realize_ctx.insn_id_gen( - "{}_{}_init".format(insn.id, "_".join(expr.inames))) + f"{red_realize_ctx.id_prefix}_{'_'.join(expr.inames)}_init") - expression, callables_table = expr.operation.neutral_element( - *arg_dtypes, callables_table=callables_table, - target=red_realize_ctx.orig_kernel.target) + expression, red_realize_ctx.boxed_callables_table[0] = \ + expr.operation.neutral_element( + *arg_dtypes, + callables_table=red_realize_ctx.boxed_callables_table[0], + target=red_realize_ctx.orig_kernel.target) init_insn = make_assignment( id=init_id, assignees=acc_vars, - within_inames=outer_insn_inames - frozenset(expr.inames), - within_inames_is_final=insn.within_inames_is_final, - depends_on=init_insn_depends_on, + within_inames=red_realize_ctx.surrounding_within_inames, + within_inames_is_final=True, + depends_on=frozenset(), expression=expression, # Do not inherit predicates: Those might read variables @@ -934,61 +1000,60 @@ def map_reduction_seq( red_realize_ctx.additional_insns.append(init_insn) update_id = red_realize_ctx.insn_id_gen( - based_on="{}_{}_update".format(insn.id, "_".join(expr.inames))) + based_on=f"{red_realize_ctx.id_prefix}_{'_'.join(expr.inames)}_update") - update_insn_iname_deps = insn.within_inames | set(expr.inames) - if insn.within_inames_is_final: - update_insn_iname_deps = insn.within_inames | set(expr.inames) + update_red_realize_ctx = red_realize_ctx.new_subinstruction( + within_inames=( + red_realize_ctx.surrounding_within_inames + | frozenset(expr.inames)), + depends_on=( + frozenset({init_id}) + | red_realize_ctx.surrounding_depends_on)) - reduction_insn_depends_on = {init_id} + reduction_expr = red_realize_ctx.mapper( + expr.expr, red_realize_ctx=update_red_realize_ctx, + nresults=1) # In the case of a multi-argument reduction, we need a name for each of # the arguments in order to pass them to the binary op - so we expand # items that are not "plain" tuples here. - if nresults > 1 and not isinstance(expr.expr, tuple): + if nresults > 1 and not isinstance(reduction_expr, tuple): get_args_insn_id = red_realize_ctx.insn_id_gen( - "{}_{}_get".format(insn.id, "_".join(expr.inames))) + f"{red_realize_ctx.id_prefix}_{'_'.join(expr.inames)}_get") reduction_expr = expand_inner_reduction( red_realize_ctx=red_realize_ctx, id=get_args_insn_id, - expr=expr.expr, + expr=reduction_expr, nresults=nresults, - depends_on=insn.depends_on, - within_inames=update_insn_iname_deps, - within_inames_is_final=insn.within_inames_is_final, - predicates=guarding_predicates, + depends_on=red_realize_ctx.surrounding_depends_on, + within_inames=update_red_realize_ctx.surrounding_within_inames, + predicates=red_realize_ctx.surrounding_predicates, ) - reduction_insn_depends_on.add(get_args_insn_id) - else: - reduction_expr = expr.expr + update_red_realize_ctx.surrounding_insn_add_depends_on.add(get_args_insn_id) - expression, callables_table = expr.operation( + expression, red_realize_ctx.boxed_callables_table[0] = expr.operation( arg_dtypes, _strip_if_scalar(acc_vars, acc_vars), reduction_expr, - callables_table, + red_realize_ctx.boxed_callables_table[0], orig_kernel.target) reduction_insn = make_assignment( id=update_id, assignees=acc_vars, expression=expression, - depends_on=frozenset(reduction_insn_depends_on) | insn.depends_on, - within_inames=update_insn_iname_deps, - within_inames_is_final=insn.within_inames_is_final, - predicates=guarding_predicates,) + **update_red_realize_ctx.get_insn_kwargs()) red_realize_ctx.additional_insns.append(reduction_insn) - - red_realize_ctx.new_insn_add_depends_on.add(reduction_insn.id) + red_realize_ctx.surrounding_insn_add_depends_on.add(reduction_insn.id) if nresults == 1: assert len(acc_vars) == 1 - return acc_vars[0], callables_table + return acc_vars[0] else: - return acc_vars, callables_table + return acc_vars # }}} @@ -1024,30 +1089,26 @@ def _make_slab_set_from_range(iname, lbound, ubound): return bs -def map_reduction_local( - red_realize_ctx, - expr, rec, callables_table, nresults, arg_dtypes, - reduction_dtypes, guarding_predicates): +def map_reduction_local(red_realize_ctx, expr, nresults, arg_dtypes, + reduction_dtypes): orig_kernel = red_realize_ctx.orig_kernel - insn = red_realize_ctx.insn red_iname, = expr.inames size = _get_int_iname_size(orig_kernel, red_iname) - outer_insn_inames = insn.within_inames - from loopy.kernel.data import LocalInameTagBase - outer_local_inames = tuple(oiname for oiname in outer_insn_inames + surrounding_local_inames = tuple( + oiname for oiname in red_realize_ctx.surrounding_within_inames if orig_kernel.iname_tags_of_type(oiname, LocalInameTagBase)) from pymbolic import var outer_local_iname_vars = tuple( - var(oiname) for oiname in outer_local_inames) + var(oiname) for oiname in surrounding_local_inames) outer_local_iname_sizes = tuple( _get_int_iname_size(orig_kernel, oiname) - for oiname in outer_local_inames) + for oiname in surrounding_local_inames) neutral_var_names = _make_temporaries( red_realize_ctx=red_realize_ctx, @@ -1079,19 +1140,22 @@ def map_reduction_local( # }}} - base_iname_deps = outer_insn_inames - frozenset(expr.inames) - - neutral, callables_table = expr.operation.neutral_element(*arg_dtypes, - callables_table=callables_table, target=orig_kernel.target) - init_id = red_realize_ctx.insn_id_gen(f"{insn.id}_{red_iname}_init") + neutral, red_realize_ctx.boxed_callables_table[0] = \ + expr.operation.neutral_element(*arg_dtypes, + callables_table=red_realize_ctx.boxed_callables_table[0], + target=orig_kernel.target) + init_id = red_realize_ctx.insn_id_gen( + f"{red_realize_ctx.id_prefix}_{red_iname}_init") init_insn = make_assignment( id=init_id, assignees=tuple( acc_var[outer_local_iname_vars + (var(base_exec_iname),)] for acc_var in acc_vars), expression=neutral, - within_inames=base_iname_deps | frozenset([base_exec_iname]), - within_inames_is_final=insn.within_inames_is_final, + within_inames=( + red_realize_ctx.surrounding_within_inames + | frozenset([base_exec_iname])), + within_inames_is_final=True, depends_on=frozenset(), # Do not inherit predicates: Those might read variables # that may not yet be set, and we don't have a great way @@ -1105,52 +1169,65 @@ def map_reduction_local( red_realize_ctx.additional_insns.append(init_insn) init_neutral_id = red_realize_ctx.insn_id_gen( - f"{insn.id}_{red_iname}_init_neutral") + f"{red_realize_ctx.id_prefix}_{red_iname}_init_neutral") init_neutral_insn = make_assignment( id=init_neutral_id, assignees=tuple(var(nvn) for nvn in neutral_var_names), expression=neutral, - within_inames=base_iname_deps | frozenset([base_exec_iname]), - within_inames_is_final=insn.within_inames_is_final, + within_inames=( + red_realize_ctx.surrounding_within_inames + | frozenset([base_exec_iname])), + within_inames_is_final=True, depends_on=frozenset(), - predicates=guarding_predicates, + predicates=red_realize_ctx.surrounding_predicates, ) red_realize_ctx.additional_insns.append(init_neutral_insn) transfer_depends_on = {init_neutral_id, init_id} + transfer_red_realize_ctx = red_realize_ctx.new_subinstruction( + within_inames=( + red_realize_ctx.surrounding_within_inames + | frozenset([red_iname])), + depends_on=( + red_realize_ctx.surrounding_depends_on + | frozenset([init_id, init_neutral_id])), + no_sync_with=( + red_realize_ctx.surrounding_no_sync_with + | frozenset([(init_id, "any")]))) + + reduction_expr = red_realize_ctx.mapper( + expr.expr, red_realize_ctx=transfer_red_realize_ctx, + nresults=1) + # In the case of a multi-argument reduction, we need a name for each of # the arguments in order to pass them to the binary op - so we expand # items that are not "plain" tuples here. - if nresults > 1 and not isinstance(expr.expr, tuple): + if nresults > 1 and not isinstance(reduction_expr, tuple): get_args_insn_id = red_realize_ctx.insn_id_gen( - f"{insn.id}_{red_iname}_get") + f"{red_realize_ctx.id_prefix}_{red_iname}_get") reduction_expr = expand_inner_reduction( red_realize_ctx=red_realize_ctx, id=get_args_insn_id, - expr=expr.expr, + expr=reduction_expr, nresults=nresults, - depends_on=insn.depends_on, - within_inames=( - (outer_insn_inames - frozenset(expr.inames)) - | frozenset([red_iname])), - within_inames_is_final=insn.within_inames_is_final, - predicates=guarding_predicates, + depends_on=red_realize_ctx.surrounding_depends_on, + within_inames=transfer_red_realize_ctx.surrounding_within_inames, + predicates=red_realize_ctx.surrounding_predicates, ) transfer_depends_on.add(get_args_insn_id) - else: - reduction_expr = expr.expr - transfer_id = red_realize_ctx.insn_id_gen(f"{insn.id}_{red_iname}_transfer") - expression, callables_table = expr.operation( + transfer_id = red_realize_ctx.insn_id_gen( + f"{red_realize_ctx.id_prefix}_{red_iname}_transfer") + expression, red_realize_ctx.boxed_callables_table[0] = expr.operation( arg_dtypes, _strip_if_scalar( neutral_var_names, tuple(var(nvn) for nvn in neutral_var_names)), reduction_expr, - callables_table, + red_realize_ctx.boxed_callables_table[0], orig_kernel.target) transfer_insn = make_assignment( id=transfer_id, @@ -1158,14 +1235,7 @@ def map_reduction_local( acc_var[outer_local_iname_vars + (var(red_iname),)] for acc_var in acc_vars), expression=expression, - within_inames=( - (outer_insn_inames - frozenset(expr.inames)) - | frozenset([red_iname])), - within_inames_is_final=insn.within_inames_is_final, - depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on, - no_sync_with=frozenset([(init_id, "any")]), - predicates=insn.predicates, - ) + **transfer_red_realize_ctx.get_insn_kwargs()) red_realize_ctx.additional_insns.append(transfer_insn) cur_size = 1 @@ -1193,7 +1263,7 @@ def map_reduction_local( stage_id = red_realize_ctx.insn_id_gen( "red_%s_stage_%d" % (red_iname, istage)) - expression, callables_table = expr.operation( + expression, red_realize_ctx.boxed_callables_table[0] = expr.operation( arg_dtypes, _strip_if_scalar(acc_vars, tuple( acc_var[ @@ -1204,7 +1274,7 @@ def map_reduction_local( outer_local_iname_vars + ( var(stage_exec_iname) + new_size,)] for acc_var in acc_vars)), - callables_table, + red_realize_ctx.boxed_callables_table[0], orig_kernel.target) stage_insn = make_assignment( @@ -1214,10 +1284,11 @@ def map_reduction_local( for acc_var in acc_vars), expression=expression, within_inames=( - base_iname_deps | frozenset([stage_exec_iname])), - within_inames_is_final=insn.within_inames_is_final, + red_realize_ctx.surrounding_within_inames + | frozenset([stage_exec_iname])), + within_inames_is_final=True, depends_on=frozenset([prev_id]), - predicates=insn.predicates, + predicates=red_realize_ctx.surrounding_predicates, ) red_realize_ctx.additional_insns.append(stage_insn) @@ -1227,17 +1298,17 @@ def map_reduction_local( bound = cur_size istage += 1 - red_realize_ctx.new_insn_add_depends_on.add(prev_id) - red_realize_ctx.new_insn_add_no_sync_with.add((prev_id, "any")) - red_realize_ctx.new_insn_add_within_inames.add( + red_realize_ctx.surrounding_insn_add_depends_on.add(prev_id) + red_realize_ctx.surrounding_insn_add_no_sync_with.add((prev_id, "any")) + red_realize_ctx.surrounding_insn_add_within_inames.add( stage_exec_iname or base_exec_iname) if nresults == 1: assert len(acc_vars) == 1 - return acc_vars[0][outer_local_iname_vars + (0,)], callables_table + return acc_vars[0][outer_local_iname_vars + (0,)] else: return [acc_var[outer_local_iname_vars + (0,)] for acc_var in - acc_vars], callables_table + acc_vars] # }}} @@ -1246,16 +1317,17 @@ def map_reduction_local( @memoize_on_first_arg def _get_or_add_sweep_tracking_iname_and_domain( red_realize_ctx, - scan_iname, sweep_iname, sweep_min_value, scan_min_value, stride, + scan_param, tracking_iname): kernel = red_realize_ctx.kernel - domain = kernel.get_inames_domain(frozenset((scan_iname, sweep_iname))) + domain = kernel.get_inames_domain( + frozenset((scan_param.scan_iname, scan_param.sweep_iname))) red_realize_ctx.inames_added_for_scan.add(tracking_iname) - new_domain = _create_domain_for_sweep_tracking(domain, - tracking_iname, sweep_iname, sweep_min_value, scan_min_value, stride) + new_domain = _create_domain_for_sweep_tracking( + domain, tracking_iname, scan_param) _insert_subdomain_into_domain_tree(kernel, red_realize_ctx.domains, new_domain) @@ -1268,6 +1340,9 @@ def replace_var_within_expr(kernel, var_name_gen, expr, from_var, to_var): from loopy.symbolic import ( SubstitutionRuleMappingContext, RuleAwareSubstitutionMapper) + # FIXME: This is broken. SubstitutionRuleMappingContext produces a new + # kernel (via finish_kernel) with new subst rules. These get dropped on the + # floor here. rule_mapping_context = SubstitutionRuleMappingContext( kernel.substitutions, var_name_gen) @@ -1302,28 +1377,21 @@ def _make_temporaries( # {{{ reduction type: sequential scan -def map_scan_seq( - red_realize_ctx, - expr, rec, callables_table, nresults, arg_dtypes, - reduction_dtypes, sweep_iname, scan_iname, sweep_min_value, - scan_min_value, stride, guarding_predicates): - insn = red_realize_ctx.insn - - outer_insn_inames = insn.within_inames +def map_scan_seq(red_realize_ctx, expr, nresults, arg_dtypes, + reduction_dtypes, scan_param): track_iname = red_realize_ctx.var_name_gen( "{sweep_iname}__seq_scan" - .format(sweep_iname=sweep_iname)) + .format(sweep_iname=scan_param.sweep_iname)) _get_or_add_sweep_tracking_iname_and_domain( - red_realize_ctx, - scan_iname, sweep_iname, sweep_min_value, scan_min_value, - stride, track_iname) + red_realize_ctx, scan_param, track_iname) + red_realize_ctx.additional_iname_tags[track_iname] = frozenset() from loopy.kernel.data import AddressSpace acc_var_names = _make_temporaries( red_realize_ctx=red_realize_ctx, - name_based_on="acc_" + scan_iname, + name_based_on="acc_" + scan_param.scan_iname, nvars=nresults, shape=(), dtypes=reduction_dtypes, @@ -1333,28 +1401,22 @@ def map_scan_seq( acc_vars = tuple(var(n) for n in acc_var_names) init_id = red_realize_ctx.insn_id_gen( - "{}_{}_init".format(insn.id, "_".join(expr.inames))) + f"{red_realize_ctx.id_prefix}_{'_'.join(expr.inames)}_init") init_insn_depends_on = frozenset() - # FIXME: Explain why we care about global barriers here - if kernel_has_global_barriers(red_realize_ctx.orig_kernel): - global_barrier = find_most_recent_global_barrier( - red_realize_ctx.kernel, insn.id) - - if global_barrier is not None: - init_insn_depends_on |= frozenset([global_barrier]) - - expression, callables_table = expr.operation.neutral_element( - *arg_dtypes, callables_table=callables_table, - target=red_realize_ctx.orig_kernel.target) + expression, red_realize_ctx.boxed_callables_table[0] = \ + expr.operation.neutral_element(*arg_dtypes, + callables_table=red_realize_ctx.boxed_callables_table[0], + target=red_realize_ctx.orig_kernel.target) init_insn = make_assignment( id=init_id, assignees=acc_vars, - within_inames=outer_insn_inames - frozenset( - (sweep_iname,) + expr.inames), - within_inames_is_final=insn.within_inames_is_final, + within_inames=( + red_realize_ctx.surrounding_within_inames + - frozenset((scan_param.sweep_iname,) + expr.inames)), + within_inames_is_final=True, depends_on=init_insn_depends_on, expression=expression, # Do not inherit predicates: Those might read variables @@ -1369,78 +1431,86 @@ def map_scan_seq( red_realize_ctx.additional_insns.append(init_insn) - update_insn_depends_on = {init_insn.id} | insn.depends_on + scan_insn_depends_on = {init_insn.id} | red_realize_ctx.surrounding_depends_on - updated_inner_exprs = _preprocess_scan_arguments( - red_realize_ctx, - expr.expr, nresults, - scan_iname, track_iname, update_insn_depends_on, - insn_id_gen=red_realize_ctx.insn_id_gen) + scan_red_realize_ctx = red_realize_ctx.new_subinstruction( + within_inames=( + red_realize_ctx.surrounding_within_inames + | frozenset({scan_param.scan_iname})), + depends_on=red_realize_ctx.surrounding_depends_on) - update_id = red_realize_ctx.insn_id_gen( - based_on="{}_{}_update".format(insn.id, "_".join(expr.inames))) + reduction_expr = red_realize_ctx.mapper( + expr.expr, red_realize_ctx=scan_red_realize_ctx, + nresults=1) + + updated_inner_exprs, scan_insn_depends_on = _preprocess_scan_arguments( + scan_red_realize_ctx, + reduction_expr, nresults, + scan_param.scan_iname, track_iname, scan_insn_depends_on) - update_insn_iname_deps = insn.within_inames | {track_iname} - if insn.within_inames_is_final: - update_insn_iname_deps = insn.within_inames | {track_iname} + scan_id = red_realize_ctx.insn_id_gen( + based_on=f"{red_realize_ctx.id_prefix}_{'_'.join(expr.inames)}_scan") - expression, callables_table = expr.operation( + expression, red_realize_ctx.boxed_callables_table[0] = expr.operation( arg_dtypes, _strip_if_scalar(acc_vars, acc_vars), _strip_if_scalar(acc_vars, updated_inner_exprs), - callables_table, + red_realize_ctx.boxed_callables_table[0], red_realize_ctx.orig_kernel.target) scan_insn = make_assignment( - id=update_id, + id=scan_id, assignees=acc_vars, expression=expression, - depends_on=frozenset(update_insn_depends_on), - within_inames=update_insn_iname_deps, - no_sync_with=insn.no_sync_with, - within_inames_is_final=insn.within_inames_is_final, - predicates=guarding_predicates, + within_inames=( + red_realize_ctx.surrounding_within_inames + | frozenset( + scan_red_realize_ctx.surrounding_insn_add_within_inames) + | {track_iname}), + depends_on=( + frozenset(scan_insn_depends_on) + | frozenset(scan_red_realize_ctx.surrounding_insn_add_depends_on) + ), + no_sync_with=( + red_realize_ctx.surrounding_no_sync_with + | frozenset(scan_red_realize_ctx.surrounding_insn_add_no_sync_with) + ), + within_inames_is_final=True, + predicates=red_realize_ctx.surrounding_predicates, ) red_realize_ctx.additional_insns.append(scan_insn) - red_realize_ctx.new_insn_add_depends_on.add(scan_insn.id) + red_realize_ctx.surrounding_insn_add_depends_on.add(scan_insn.id) if nresults == 1: assert len(acc_vars) == 1 - return acc_vars[0], callables_table + return acc_vars[0] else: - return acc_vars, callables_table + return acc_vars # }}} # {{{ reduction type: local-parallel scan -def map_scan_local( - red_realize_ctx, - expr, rec, callables_table, nresults, arg_dtypes, - reduction_dtypes, sweep_iname, scan_iname, sweep_min_value, - scan_min_value, stride, guarding_predicates): +def map_scan_local(red_realize_ctx, expr, nresults, arg_dtypes, + reduction_dtypes, scan_param): orig_kernel = red_realize_ctx.orig_kernel - insn = red_realize_ctx.insn - scan_size = _get_int_iname_size(orig_kernel, sweep_iname) + scan_size = _get_int_iname_size(orig_kernel, scan_param.sweep_iname) assert scan_size > 0 if scan_size == 1: return map_reduction_seq(red_realize_ctx, - expr, rec, callables_table, - nresults, arg_dtypes, reduction_dtypes, - guarding_predicates) - - outer_insn_inames = insn.within_inames + expr, nresults, arg_dtypes, reduction_dtypes) from loopy.kernel.data import LocalInameTagBase - outer_local_inames = tuple(oiname for oiname in outer_insn_inames + outer_local_inames = tuple( + oiname for oiname in red_realize_ctx.surrounding_within_inames if orig_kernel.iname_tags_of_type(oiname, LocalInameTagBase) - and oiname != sweep_iname) + and oiname != scan_param.sweep_iname) from pymbolic import var outer_local_iname_vars = tuple( @@ -1452,28 +1522,29 @@ def map_scan_local( track_iname = red_realize_ctx.var_name_gen( "{sweep_iname}__pre_scan" - .format(sweep_iname=sweep_iname)) + .format(sweep_iname=scan_param.sweep_iname)) _get_or_add_sweep_tracking_iname_and_domain( red_realize_ctx, - scan_iname, sweep_iname, sweep_min_value, scan_min_value, stride, + scan_param, track_iname) + red_realize_ctx.additional_iname_tags[track_iname] = frozenset() # {{{ add separate iname to carry out the scan # Doing this sheds any odd conditionals that may be active # on our scan_iname. - base_exec_iname = red_realize_ctx.var_name_gen(sweep_iname + "__scan") + base_exec_iname = red_realize_ctx.var_name_gen(scan_param.sweep_iname + "__scan") red_realize_ctx.domains.append(_make_slab_set(base_exec_iname, scan_size)) red_realize_ctx.additional_iname_tags[base_exec_iname] \ - = orig_kernel.iname_tags(sweep_iname) + = orig_kernel.iname_tags(scan_param.sweep_iname) # }}} read_var_names = _make_temporaries( red_realize_ctx=red_realize_ctx, - name_based_on="read_"+scan_iname+"_arg_{index}", + name_based_on="read_"+scan_param.scan_iname+"_arg_{index}", nvars=nresults, shape=(), dtypes=reduction_dtypes, @@ -1481,7 +1552,7 @@ def map_scan_local( acc_var_names = _make_temporaries( red_realize_ctx=red_realize_ctx, - name_based_on="acc_"+scan_iname, + name_based_on="acc_"+scan_param.scan_iname, nvars=nresults, shape=outer_local_iname_sizes + (scan_size,), dtypes=reduction_dtypes, @@ -1490,24 +1561,17 @@ def map_scan_local( acc_vars = tuple(var(n) for n in acc_var_names) read_vars = tuple(var(n) for n in read_var_names) - base_iname_deps = (outer_insn_inames - - frozenset(expr.inames) - frozenset([sweep_iname])) - - neutral, callables_table = expr.operation.neutral_element( - *arg_dtypes, callables_table=callables_table, - target=orig_kernel.target) - - init_insn_depends_on = insn.depends_on - - # FIXME: Explain why we care about global barriers here - if kernel_has_global_barriers(orig_kernel): - global_barrier = find_most_recent_global_barrier( - red_realize_ctx.kernel, insn.id) + base_iname_deps = ( + red_realize_ctx.surrounding_within_inames + - frozenset([scan_param.sweep_iname])) - if global_barrier is not None: - init_insn_depends_on |= frozenset([global_barrier]) + neutral, red_realize_ctx.boxed_callables_table[0] = \ + expr.operation.neutral_element(*arg_dtypes, + callables_table=red_realize_ctx.boxed_callables_table[0], + target=orig_kernel.target) - init_id = red_realize_ctx.insn_id_gen(f"{insn.id}_{scan_iname}_init") + init_id = red_realize_ctx.insn_id_gen( + f"{red_realize_ctx.id_prefix}_{scan_param.scan_iname}_init") init_insn = make_assignment( id=init_id, assignees=tuple( @@ -1515,8 +1579,8 @@ def map_scan_local( for acc_var in acc_vars), expression=neutral, within_inames=base_iname_deps | frozenset([base_exec_iname]), - within_inames_is_final=insn.within_inames_is_final, - depends_on=init_insn_depends_on, + within_inames_is_final=True, + depends_on=frozenset(), # Do not inherit predicates: Those might read variables # that may not yet be set, and we don't have a great way # of figuring out what the dependencies of the accumulator @@ -1528,57 +1592,88 @@ def map_scan_local( ) red_realize_ctx.additional_insns.append(init_insn) - transfer_insn_depends_on = {init_insn.id} | insn.depends_on + transfer_insn_depends_on = ( + frozenset({init_insn.id}) + | red_realize_ctx.surrounding_depends_on) - updated_inner_exprs = _preprocess_scan_arguments( - red_realize_ctx, - expr.expr, nresults, - scan_iname, track_iname, transfer_insn_depends_on, - insn_id_gen=red_realize_ctx.insn_id_gen) + transfer_red_realize_ctx = red_realize_ctx.new_subinstruction( + within_inames=( + red_realize_ctx.surrounding_within_inames + | frozenset({scan_param.scan_iname})), + depends_on=red_realize_ctx.surrounding_depends_on) - from loopy.symbolic import Reduction + reduction_expr = red_realize_ctx.mapper( + expr.expr, red_realize_ctx=transfer_red_realize_ctx, + nresults=1) - from loopy.symbolic import pw_aff_to_expr - sweep_min_value_expr = pw_aff_to_expr(sweep_min_value) + updated_inner_exprs, transfer_insn_depends_on = _preprocess_scan_arguments( + red_realize_ctx, + reduction_expr, nresults, + scan_param.scan_iname, track_iname, transfer_insn_depends_on) - transfer_id = red_realize_ctx.insn_id_gen(f"{insn.id}_{scan_iname}_transfer") - transfer_insn = make_assignment( - id=transfer_id, - assignees=tuple( - acc_var[outer_local_iname_vars - + (var(sweep_iname) - sweep_min_value_expr,)] - for acc_var in acc_vars), - expression=Reduction( + from loopy.symbolic import Reduction + pre_scan_reduction = Reduction( operation=expr.operation, inames=(track_iname,), expr=_strip_if_scalar(acc_vars, updated_inner_exprs), allow_simultaneous=False, - ), - within_inames=outer_insn_inames - frozenset(expr.inames), - within_inames_is_final=insn.within_inames_is_final, - depends_on=frozenset(transfer_insn_depends_on), - no_sync_with=frozenset([(init_id, "any")]) | insn.no_sync_with, - predicates=insn.predicates, - ) + ) - red_realize_ctx.additional_insns.append(transfer_insn) + pre_scan_result = red_realize_ctx.mapper( + pre_scan_reduction, red_realize_ctx=transfer_red_realize_ctx, + nresults=len(acc_vars)) - prev_id = transfer_id + from loopy.symbolic import pw_aff_to_expr + sweep_lower_bound_expr = pw_aff_to_expr(scan_param.sweep_lower_bound) + + if nresults == 1: + assert not isinstance(pre_scan_result, tuple) + pre_scan_result = (pre_scan_result,) + + transfer_ids = frozenset() + for acc_var, pre_scan_result_i in zip(acc_vars, pre_scan_result): + transfer_id = red_realize_ctx.insn_id_gen( + f"{red_realize_ctx.id_prefix}_{scan_param.scan_iname}_transfer") + transfer_insn = make_assignment( + id=transfer_id, + assignees=(acc_var[outer_local_iname_vars + + (var(scan_param.sweep_iname) - sweep_lower_bound_expr,)],), + expression=pre_scan_result_i, + within_inames=( + red_realize_ctx.surrounding_within_inames + | transfer_red_realize_ctx.surrounding_insn_add_within_inames + | frozenset({scan_param.sweep_iname})), + within_inames_is_final=True, + depends_on=( + transfer_insn_depends_on + | transfer_red_realize_ctx.surrounding_insn_add_depends_on), + no_sync_with=( + frozenset([(init_id, "any")]) + | transfer_red_realize_ctx.surrounding_insn_add_no_sync_with), + predicates=red_realize_ctx.surrounding_predicates, + ) + + red_realize_ctx.additional_insns.append(transfer_insn) + transfer_ids = transfer_ids | frozenset({transfer_id}) + + del transfer_id + + prev_ids = transfer_ids istage = 0 cur_size = 1 while cur_size < scan_size: stage_exec_iname = red_realize_ctx.var_name_gen( - "%s__scan_s%d" % (sweep_iname, istage)) + f"{scan_param.sweep_iname}__scan_s{istage}") red_realize_ctx.domains.append( _make_slab_set_from_range(stage_exec_iname, cur_size, scan_size)) red_realize_ctx.additional_iname_tags[stage_exec_iname] \ - = orig_kernel.iname_tags(sweep_iname) + = orig_kernel.iname_tags(scan_param.sweep_iname) for read_var, acc_var in zip(read_vars, acc_vars): read_stage_id = red_realize_ctx.insn_id_gen( - "scan_%s_read_stage_%d" % (scan_iname, istage)) + f"scan_{scan_param.scan_iname}_read_stage_{istage}") read_stage_insn = make_assignment( id=read_stage_id, @@ -1589,9 +1684,9 @@ def map_scan_local( + (var(stage_exec_iname) - cur_size,)]), within_inames=( base_iname_deps | frozenset([stage_exec_iname])), - within_inames_is_final=insn.within_inames_is_final, - depends_on=frozenset([prev_id]), - predicates=insn.predicates, + within_inames_is_final=True, + depends_on=prev_ids, + predicates=red_realize_ctx.surrounding_predicates, ) if cur_size == 1: @@ -1601,22 +1696,22 @@ def map_scan_local( read_stage_insn = read_stage_insn.copy( no_sync_with=( read_stage_insn.no_sync_with - | frozenset([(transfer_id, "any")]))) + | frozenset([(tid, "any") for tid in transfer_ids]))) red_realize_ctx.additional_insns.append(read_stage_insn) - prev_id = read_stage_id + prev_ids = frozenset({read_stage_id}) write_stage_id = red_realize_ctx.insn_id_gen( - "scan_%s_write_stage_%d" % (scan_iname, istage)) + f"scan_{scan_param.scan_iname}_write_stage_{istage}") - expression, callables_table = expr.operation( + expression, red_realize_ctx.boxed_callables_table[0] = expr.operation( arg_dtypes, _strip_if_scalar(acc_vars, read_vars), _strip_if_scalar(acc_vars, tuple( acc_var[ outer_local_iname_vars + (var(stage_exec_iname),)] for acc_var in acc_vars)), - callables_table, + red_realize_ctx.boxed_callables_table[0], orig_kernel.target) write_stage_insn = make_assignment( @@ -1627,58 +1722,52 @@ def map_scan_local( expression=expression, within_inames=( base_iname_deps | frozenset([stage_exec_iname])), - within_inames_is_final=insn.within_inames_is_final, - depends_on=frozenset([prev_id]), - predicates=insn.predicates, + within_inames_is_final=True, + depends_on=prev_ids, + predicates=red_realize_ctx.surrounding_predicates, ) red_realize_ctx.additional_insns.append(write_stage_insn) - prev_id = write_stage_id + prev_ids = frozenset({write_stage_id}) cur_size *= 2 istage += 1 - red_realize_ctx.new_insn_add_depends_on.add(prev_id) - red_realize_ctx.new_insn_add_within_inames.add(sweep_iname) + red_realize_ctx.surrounding_insn_add_depends_on.update(prev_ids) + red_realize_ctx.surrounding_insn_add_within_inames.add(scan_param.sweep_iname) - output_idx = var(sweep_iname) - sweep_min_value_expr + output_idx = var(scan_param.sweep_iname) - sweep_lower_bound_expr if nresults == 1: assert len(acc_vars) == 1 - return (acc_vars[0][outer_local_iname_vars + (output_idx,)], - callables_table) + return acc_vars[0][outer_local_iname_vars + (output_idx,)] else: return [acc_var[outer_local_iname_vars + (output_idx,)] - for acc_var in acc_vars], callables_table + for acc_var in acc_vars] # }}} # {{{ top-level dispatch among reduction types -def map_reduction( - expr, *, rec, - callables_table, red_realize_ctx, - guarding_predicates, nresults): - insn = red_realize_ctx.insn - - # Only expand one level of reduction at a time, going from outermost to - # innermost. Otherwise we get the (iname + insn) dependencies wrong. +def map_reduction(expr, *, red_realize_ctx, nresults): + kernel_with_updated_domains = red_realize_ctx.kernel.copy( + domains=red_realize_ctx.domains) from loopy.type_inference import ( infer_arg_and_reduction_dtypes_for_reduction_expression) arg_dtypes, reduction_dtypes = ( infer_arg_and_reduction_dtypes_for_reduction_expression( - red_realize_ctx.kernel, expr, callables_table, + kernel_with_updated_domains, expr, + red_realize_ctx.boxed_callables_table[0], red_realize_ctx.unknown_types_ok)) - outer_insn_inames = insn.within_inames - bad_inames = frozenset(expr.inames) & outer_insn_inames + bad_inames = frozenset(expr.inames) & red_realize_ctx.surrounding_within_inames if bad_inames: raise LoopyError("reduction used within loop(s) that it was " "supposed to reduce over: " + ", ".join(bad_inames)) - iname_classes = _classify_reduction_inames(red_realize_ctx.kernel, expr.inames) + iname_classes = _classify_reduction_inames(red_realize_ctx, expr.inames) n_sequential = len(iname_classes.sequential) n_local_par = len(iname_classes.local_parallel) @@ -1698,7 +1787,8 @@ def map_reduction( # Try to determine scan candidate information (sweep iname, scan # iname, etc). scan_param = _try_infer_scan_candidate_from_expr( - red_realize_ctx.kernel, expr, outer_insn_inames, + kernel_with_updated_domains, expr, + red_realize_ctx.surrounding_within_inames, sweep_iname=red_realize_ctx.force_outer_iname_for_scan) except ValueError as v: @@ -1707,7 +1797,7 @@ def map_reduction( else: # Ensures the reduction is triangular (somewhat expensive). may_be_implemented_as_scan, error = _check_reduction_is_triangular( - red_realize_ctx.kernel, expr, scan_param) + kernel_with_updated_domains, expr, scan_param) if not may_be_implemented_as_scan: _error_if_force_scan_on(ReductionIsNotTriangularError, error) @@ -1751,7 +1841,7 @@ def map_reduction( # to reduce over. It's rather similar to an array with () shape in # numpy.) - return expr.expr, callables_table + return expr.expr if may_be_implemented_as_scan: assert red_realize_ctx.force_scan or red_realize_ctx.automagic_scans_ok @@ -1759,14 +1849,13 @@ def map_reduction( # We require the "scan" iname to be tagged sequential. if n_sequential: sweep_iname = scan_param.sweep_iname - sweep_class = _classify_reduction_inames( - red_realize_ctx.orig_kernel, (sweep_iname,)) + sweep_class = _classify_reduction_inames(red_realize_ctx, (sweep_iname,)) sequential = sweep_iname in sweep_class.sequential parallel = sweep_iname in sweep_class.local_parallel bad_parallel = sweep_iname in sweep_class.nonlocal_parallel - if sweep_iname not in outer_insn_inames: + if sweep_iname not in red_realize_ctx.surrounding_within_inames: _error_if_force_scan_on(LoopyError, "Sweep iname '%s' was detected, but is not an iname " "for the instruction." % sweep_iname) @@ -1778,25 +1867,11 @@ def map_reduction( ", ".join(tag.key for tag in red_realize_ctx.kernel.iname_tags(sweep_iname)))) elif parallel: - return map_scan_local( - red_realize_ctx, - expr, rec, callables_table, nresults, - arg_dtypes, reduction_dtypes, - sweep_iname, scan_param.scan_iname, - scan_param.sweep_lower_bound, - scan_param.scan_lower_bound, - scan_param.stride, - guarding_predicates) + return map_scan_local(red_realize_ctx, expr, nresults, + arg_dtypes, reduction_dtypes, scan_param) elif sequential: - return map_scan_seq( - red_realize_ctx, - expr, rec, callables_table, nresults, - arg_dtypes, reduction_dtypes, sweep_iname, - scan_param.scan_iname, - scan_param.sweep_lower_bound, - scan_param.scan_lower_bound, - scan_param.stride, - guarding_predicates) + return map_scan_seq(red_realize_ctx, expr, nresults, + arg_dtypes, reduction_dtypes, scan_param) # fallthrough to reduction implementation @@ -1814,15 +1889,13 @@ def map_reduction( assert n_local_par == 0 return map_reduction_seq( red_realize_ctx, - expr, rec, callables_table, - nresults, arg_dtypes, reduction_dtypes, - guarding_predicates) + expr, nresults, arg_dtypes, reduction_dtypes) else: assert n_local_par > 0 return map_reduction_local( red_realize_ctx, - expr, rec, callables_table, nresults, arg_dtypes, - reduction_dtypes, guarding_predicates) + expr, nresults, arg_dtypes, + reduction_dtypes) # }}} @@ -1842,7 +1915,7 @@ def realize_reduction_for_single_kernel(kernel, callables_table, insn_id_gen = kernel.get_instruction_id_generator() var_name_gen = kernel.get_var_name_generator() - cb_mapper = RealizeReductionCallbackMapper(map_reduction, callables_table) + cb_mapper = RealizeReductionCallbackMapper(map_reduction) insn_queue = kernel.instructions[:] domains = kernel.domains[:] @@ -1855,6 +1928,8 @@ def realize_reduction_for_single_kernel(kernel, callables_table, insn = insn_queue.pop(0) red_realize_ctx = _ReductionRealizationContext( + mapper=cb_mapper, + force_scan=force_scan, automagic_scans_ok=automagic_scans_ok, unknown_types_ok=unknown_types_ok, @@ -1862,7 +1937,8 @@ def realize_reduction_for_single_kernel(kernel, callables_table, orig_kernel=orig_kernel, kernel=kernel, - insn=insn, + + id_prefix=insn.id, insn_id_gen=insn_id_gen, var_name_gen=var_name_gen, @@ -1871,14 +1947,20 @@ def realize_reduction_for_single_kernel(kernel, callables_table, additional_insns=[], domains=domains, additional_iname_tags={}, + boxed_callables_table=[callables_table], inames_added_for_scan=inames_added_for_scan, - new_insn_add_depends_on=set(), - new_insn_add_no_sync_with=set(), - new_insn_add_within_inames=set(), + surrounding_within_inames=insn.within_inames, + surrounding_depends_on=insn.depends_on, + surrounding_no_sync_with=insn.no_sync_with, + surrounding_predicates=insn.predicates, - were_changes_made=False, + surrounding_insn_add_within_inames=set(), + surrounding_insn_add_depends_on=set(), + surrounding_insn_add_no_sync_with=set(), + + _change_flag=_ChangeFlag(changes_made=False) ) if insn_id_filter is not None and insn.id != insn_id_filter \ @@ -1892,15 +1974,11 @@ def realize_reduction_for_single_kernel(kernel, callables_table, from loopy.symbolic import Reduction if isinstance(insn.expression, Reduction) and nresults > 1: new_expressions = cb_mapper(insn.expression, - callables_table=cb_mapper.callables_table, red_realize_ctx=red_realize_ctx, - guarding_predicates=insn.predicates, nresults=nresults) else: new_expressions = cb_mapper(insn.expression, - callables_table=cb_mapper.callables_table, red_realize_ctx=red_realize_ctx, - guarding_predicates=insn.predicates, nresults=1), if red_realize_ctx.were_changes_made: @@ -1911,17 +1989,17 @@ def realize_reduction_for_single_kernel(kernel, callables_table, kernel_changed = True - insn_id_replacements = {} + callables_table = red_realize_ctx.boxed_callables_table[0] result_assignment_dep_on = ( insn.depends_on - | frozenset(red_realize_ctx.new_insn_add_depends_on)) + | frozenset(red_realize_ctx.surrounding_insn_add_depends_on)) kwargs = insn.get_copy_kwargs( no_sync_with=insn.no_sync_with - | frozenset(red_realize_ctx.new_insn_add_no_sync_with), + | frozenset(red_realize_ctx.surrounding_insn_add_no_sync_with), within_inames=( insn.within_inames - | red_realize_ctx.new_insn_add_within_inames)) + | red_realize_ctx.surrounding_insn_add_within_inames)) kwargs.pop("id") kwargs.pop("depends_on") @@ -1931,6 +2009,8 @@ def realize_reduction_for_single_kernel(kernel, callables_table, kwargs.pop("temp_var_type", None) kwargs.pop("temp_var_types", None) + insn_id_replacements = {} + if isinstance(insn.expression, Reduction) and nresults > 1: result_assignment_ids = [ insn_id_gen(insn.id) for i in range(nresults)] @@ -1962,10 +2042,32 @@ def realize_reduction_for_single_kernel(kernel, callables_table, **kwargs) ] - insn_queue = ( - red_realize_ctx.additional_insns - + replacement_insns - + insn_queue) + additional_insns = red_realize_ctx.additional_insns + + # {{{ make additional insns depend on most recent global barrier + + # FIXME This is weird and hokey and ad-hoc and probably broken. + # I *think* the idea is to keep a reduction/scan implementation + # from crossing a global barrier, because that would be costly. + + # check first that the original kernel had global barriers + # if not, we don't need to check. Since the function + # kernel_has_global_barriers is cached, we don't do + # extra work compared to not checking. + + from loopy.kernel.tools import ( + kernel_has_global_barriers, find_most_recent_global_barrier) + + if kernel_has_global_barriers(orig_kernel): + global_barrier = find_most_recent_global_barrier(kernel, insn.id) + + if global_barrier is not None: + gb_dep = frozenset([global_barrier]) + additional_insns = [addl_insn.copy( + depends_on=addl_insn.depends_on | gb_dep) + for addl_insn in additional_insns] + + # }}} # The reduction expander needs an up-to-date kernel # object to find dependencies. Keep kernel up-to-date. @@ -1980,6 +2082,9 @@ def realize_reduction_for_single_kernel(kernel, callables_table, replace_instruction_ids_in_insn(insn, insn_id_replacements) for insn in insn_queue] + finished_insns.extend(additional_insns) + finished_insns.extend(replacement_insns) + kernel = kernel.copy( instructions=finished_insns + insn_queue, temporary_variables=new_temporary_variables, @@ -1993,19 +2098,17 @@ def realize_reduction_for_single_kernel(kernel, callables_table, else: # nothing happened, we're done with insn - assert not red_realize_ctx.new_insn_add_depends_on + assert not red_realize_ctx.surrounding_insn_add_depends_on finished_insns.append(insn) - if kernel_changed: - kernel = kernel.copy(instructions=finished_insns) - else: + if not kernel_changed: return orig_kernel, callables_table kernel = _hackily_ensure_multi_assignment_return_values_are_scoped_private( kernel) - return kernel, cb_mapper.callables_table + return kernel, callables_table # }}} diff --git a/test/test_scan.py b/test/test_scan.py index 94778ef4d..f5aa8a7c2 100644 --- a/test/test_scan.py +++ b/test/test_scan.py @@ -221,12 +221,8 @@ def test_local_parallel_scan(ctx_factory, n): knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.realize_reduction(knl, force_scan=True) - knl = lp.realize_reduction(knl) - knl = lp.add_dtypes(knl, dict(a=int)) - print(knl) - evt, (a,) = knl(queue, a=np.arange(n)) assert (a == np.cumsum(np.arange(n)**2)).all() @@ -246,7 +242,6 @@ def test_local_parallel_scan_with_nonzero_lower_bounds(ctx_factory): knl = lp.fix_parameters(knl, n=16) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.realize_reduction(knl, force_scan=True) - knl = lp.realize_reduction(knl) knl = lp.add_dtypes(knl, dict(a=int)) evt, (out,) = knl(queue, a=np.arange(1, 17)) -- GitLab