diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 3712d678b6e2a663636c5f7a1c3302af934215e4..d82b2b3520b6f539dfa59fd02a5d0b1f1a1aedeb 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -1125,140 +1125,6 @@ def has_schedulable_iname_nesting(kernel): # }}} -# {{{ rename_inames - -@for_each_kernel -def rename_inames(kernel, old_inames, new_iname, existing_ok=False, within=None): - """ - :arg old_inames: A collection of inames that must be renamed to **new_iname**. - :arg within: a stack match as understood by - :func:`loopy.match.parse_stack_match`. - :arg existing_ok: execute even if *new_iname* already exists - """ - from collections.abc import Collection - if (isinstance(old_inames, str) - or not isinstance(old_inames, Collection)): - raise LoopyError("'old_inames' must be a collection of strings, " - f"got '{type(old_inames)}'.") - - if new_iname in old_inames: - raise LoopyError("new iname is part of inames being renamed") - - if new_iname in (kernel.all_variable_names() - kernel.all_inames()): - raise LoopyError(f"New iname '{new_iname}' is already a variable in the" - "kernel") - - if any((len(insn.within_inames & frozenset(old_inames)) > 1) - for insn in kernel.instructions): - raise LoopyError("old_inames contains nested inames" - " -- renaming is illegal.") - - # sort to have deterministic implementation. - old_inames = sorted(old_inames) - - var_name_gen = kernel.get_var_name_generator() - - # FIXME: Distinguish existing iname vs. existing other variable - does_exist = new_iname in kernel.all_inames() - - if not (frozenset(old_inames) <= kernel.all_inames()): - raise LoopyError(f"old inames {frozenset(old_inames) - kernel.all_inames()}" - " do not exist.") - - if does_exist and not existing_ok: - raise LoopyError(f"iname '{new_iname}' conflicts with an existing identifier" - " --cannot rename") - - if not does_exist: - # {{{ rename old_inames[0] -> new_iname - # so that the code below can focus on "merging" inames that already exist - - kernel = duplicate_inames( - kernel, old_inames[0], within=within, new_inames=[new_iname]) - kernel = remove_unused_inames(kernel, old_inames[0]) - - # old_iname[0] is already renamed to new_iname => do not rename again. - old_inames = old_inames[1:] - - # }}} - - del does_exist - assert new_iname in kernel.all_inames() - - for old_iname in old_inames: - # {{{ check that the domains match up - - dom = kernel.get_inames_domain(frozenset((old_iname, new_iname))) - - var_dict = dom.get_var_dict() - _, old_idx = var_dict[old_iname] - _, new_idx = var_dict[new_iname] - - par_idx = dom.dim(dim_type.param) - dom_old = dom.move_dims( - dim_type.param, par_idx, dim_type.set, old_idx, 1) - dom_old = dom_old.move_dims( - dim_type.set, dom_old.dim(dim_type.set), dim_type.param, par_idx, 1) - dom_old = dom_old.project_out( - dim_type.set, new_idx if new_idx < old_idx else new_idx - 1, 1) - - par_idx = dom.dim(dim_type.param) - dom_new = dom.move_dims( - dim_type.param, par_idx, dim_type.set, new_idx, 1) - dom_new = dom_new.move_dims( - dim_type.set, dom_new.dim(dim_type.set), dim_type.param, par_idx, 1) - dom_new = dom_new.project_out( - dim_type.set, old_idx if old_idx < new_idx else old_idx - 1, 1) - - if not (dom_old <= dom_new and dom_new <= dom_old): - raise LoopyError( - "inames {old} and {new} do not iterate over the same domain" - .format(old=old_iname, new=new_iname)) - - # }}} - - from pymbolic import var - subst_dict = {old_iname: var(new_iname) for old_iname in old_inames} - - from loopy.match import parse_stack_match - within = parse_stack_match(within) - - from pymbolic.mapper.substitutor import make_subst_func - rule_mapping_context = SubstitutionRuleMappingContext( - kernel.substitutions, var_name_gen) - smap = RuleAwareSubstitutionMapper(rule_mapping_context, - make_subst_func(subst_dict), within) - - from loopy.kernel.instruction import MultiAssignmentBase - - def does_insn_involve_iname(kernel, insn, *args): - return (not isinstance(insn, MultiAssignmentBase) - or frozenset(old_inames) & insn.dependency_names() - or frozenset(old_inames) & insn.reduction_inames()) - - kernel = rule_mapping_context.finish_kernel( - smap.map_kernel(kernel, within=does_insn_involve_iname)) - - new_instructions = [insn.copy(within_inames=((insn.within_inames - - frozenset(old_inames)) - | frozenset([new_iname]))) - if ((len(frozenset(old_inames) & insn.within_inames) != 0) - and within(kernel, insn, ())) - else insn - for insn in kernel.instructions] - - kernel = kernel.copy(instructions=new_instructions) - kernel = remove_unused_inames(kernel, old_inames) - - return kernel - - -def rename_iname(kernel, old_iname, new_iname, existing_ok=False, within=None): - return rename_inames(kernel, [old_iname], new_iname, existing_ok, within) - -# }}} - - # {{{ remove unused inames def get_used_inames(kernel): @@ -2422,4 +2288,138 @@ def add_inames_for_unused_hw_axes(kernel, within=None): return kernel.copy(instructions=new_insns) + +# {{{ rename_inames + +@for_each_kernel +@remove_any_newly_unused_inames +def rename_inames(kernel, old_inames, new_iname, existing_ok=False, within=None): + """ + :arg old_inames: A collection of inames that must be renamed to **new_iname**. + :arg within: a stack match as understood by + :func:`loopy.match.parse_stack_match`. + :arg existing_ok: execute even if *new_iname* already exists + """ + from collections.abc import Collection + if (isinstance(old_inames, str) + or not isinstance(old_inames, Collection)): + raise LoopyError("'old_inames' must be a collection of strings, " + f"got '{type(old_inames)}'.") + + if new_iname in old_inames: + raise LoopyError("new iname is part of inames being renamed") + + if new_iname in (kernel.all_variable_names() - kernel.all_inames()): + raise LoopyError(f"New iname '{new_iname}' is already a variable in the" + "kernel") + + if any((len(insn.within_inames & frozenset(old_inames)) > 1) + for insn in kernel.instructions): + raise LoopyError("old_inames contains nested inames" + " -- renaming is illegal.") + + # sort to have deterministic implementation. + old_inames = sorted(old_inames) + + var_name_gen = kernel.get_var_name_generator() + + # FIXME: Distinguish existing iname vs. existing other variable + does_exist = new_iname in kernel.all_inames() + + if not (frozenset(old_inames) <= kernel.all_inames()): + raise LoopyError(f"old inames {frozenset(old_inames) - kernel.all_inames()}" + " do not exist.") + + if does_exist and not existing_ok: + raise LoopyError(f"iname '{new_iname}' conflicts with an existing identifier" + " --cannot rename") + + if not does_exist: + # {{{ rename old_inames[0] -> new_iname + # so that the code below can focus on "merging" inames that already exist + + kernel = duplicate_inames( + kernel, old_inames[0], within=within, new_inames=[new_iname]) + + # old_iname[0] is already renamed to new_iname => do not rename again. + old_inames = old_inames[1:] + + # }}} + + del does_exist + assert new_iname in kernel.all_inames() + + for old_iname in old_inames: + # {{{ check that the domains match up + + dom = kernel.get_inames_domain(frozenset((old_iname, new_iname))) + + var_dict = dom.get_var_dict() + _, old_idx = var_dict[old_iname] + _, new_idx = var_dict[new_iname] + + par_idx = dom.dim(dim_type.param) + dom_old = dom.move_dims( + dim_type.param, par_idx, dim_type.set, old_idx, 1) + dom_old = dom_old.move_dims( + dim_type.set, dom_old.dim(dim_type.set), dim_type.param, par_idx, 1) + dom_old = dom_old.project_out( + dim_type.set, new_idx if new_idx < old_idx else new_idx - 1, 1) + + par_idx = dom.dim(dim_type.param) + dom_new = dom.move_dims( + dim_type.param, par_idx, dim_type.set, new_idx, 1) + dom_new = dom_new.move_dims( + dim_type.set, dom_new.dim(dim_type.set), dim_type.param, par_idx, 1) + dom_new = dom_new.project_out( + dim_type.set, old_idx if old_idx < new_idx else old_idx - 1, 1) + + if not (dom_old <= dom_new and dom_new <= dom_old): + raise LoopyError( + "inames {old} and {new} do not iterate over the same domain" + .format(old=old_iname, new=new_iname)) + + # }}} + + from pymbolic import var + subst_dict = {old_iname: var(new_iname) for old_iname in old_inames} + + from loopy.match import parse_stack_match + within = parse_stack_match(within) + + from pymbolic.mapper.substitutor import make_subst_func + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, var_name_gen) + smap = RuleAwareSubstitutionMapper(rule_mapping_context, + make_subst_func(subst_dict), within) + + from loopy.kernel.instruction import MultiAssignmentBase + + def does_insn_involve_iname(kernel, insn, *args): + return (not isinstance(insn, MultiAssignmentBase) + or frozenset(old_inames) & insn.dependency_names() + or frozenset(old_inames) & insn.reduction_inames()) + + kernel = rule_mapping_context.finish_kernel( + smap.map_kernel(kernel, within=does_insn_involve_iname)) + + new_instructions = [insn.copy(within_inames=((insn.within_inames + - frozenset(old_inames)) + | frozenset([new_iname]))) + if ((len(frozenset(old_inames) & insn.within_inames) != 0) + and within(kernel, insn, ())) + else insn + for insn in kernel.instructions] + + kernel = kernel.copy(instructions=new_instructions) + + return kernel + + +def rename_iname(kernel, old_iname, new_iname, existing_ok=False, within=None): + return rename_inames(kernel, [old_iname], new_iname, existing_ok, within) + +# }}} + + # vim: foldmethod=marker