diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index d76e6f9fea63c704bb06dcab5edc109145e27d15..2c663aa35a752b3fe3b4e54b7b16d2fa0b167b38 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -20,7 +20,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import re import collections from pytools import ImmutableRecord @@ -375,41 +374,7 @@ class Program(TranslationUnit): # }}} -# {{{ next_indexed_function_id - -def next_indexed_function_id(function_id): - """ - Returns an instance of :class:`str` with the next indexed-name in the - sequence for the name of *function_id*. - - *Example:* ``'sin_0'`` will return ``'sin_1'``. - - :arg function_id: Either an instance of :class:`str`. - """ - - # {{{ sanity checks - - assert isinstance(function_id, str) - - # }}} - - func_name = re.compile(r"^(?P<alpha>\S+?)_(?P<num>\d+?)$") - - match = func_name.match(function_id) - - if match is None: - if function_id[-1] == "_": - return f"{function_id}0" - else: - return f"{function_id}_0" - - return "{alpha}_{num}".format(alpha=match.group("alpha"), - num=int(match.group("num"))+1) - -# }}} - - -# {{{ rename_resolved_functions_in_a_single_kernel +# {{{ rename resolved functions class ResolvedFunctionRenamer(RuleAwareIdentityMapper): """ @@ -482,10 +447,12 @@ class CallablesIDCollector(CombineMapper): return callables_in_insn + def map_type_cast(self, expr): + return self.rec(expr) + map_variable = map_constant map_function_symbol = map_constant map_tagged_variable = map_constant - map_type_cast = map_constant def _get_reachable_callable_ids_for_knl(knl, callables): @@ -508,8 +475,25 @@ def _get_reachable_callable_ids(callables, entrypoints): # {{{ CallablesInferenceContext +def get_all_subst_names(callables): + """ + Returns a :class:`set` of all substitution rule names in the callable + kernels of *callables*. 
+ + :arg callables: A mapping from function identifiers to + :class:`~loopy.kernel.function_interface.InKernelCallable`. + """ + return set().union(*(set(clbl.subkernel.substitutions.keys()) + for clbl in callables.values() + if isinstance(clbl, CallableKernel))) + + def make_clbl_inf_ctx(callables, entrypoints): - return CallablesInferenceContext(callables) + from pytools import UniqueNameGenerator + all_substs = get_all_subst_names(callables) + ung = UniqueNameGenerator(set(callables.keys()) | all_substs) + + return CallablesInferenceContext(callables, ung) class CallablesInferenceContext(ImmutableRecord): @@ -538,12 +522,13 @@ class CallablesInferenceContext(ImmutableRecord): .. automethod:: __getitem__ """ def __init__(self, callables, + clbl_name_gen, renames=collections.defaultdict(frozenset), new_entrypoints=frozenset()): assert isinstance(callables, collections.abc.Mapping) - callables = dict(callables) - super().__init__(callables=callables, + super().__init__(callables=dict(callables), + clbl_name_gen=clbl_name_gen, renames=renames, new_entrypoints=new_entrypoints) @@ -607,14 +592,8 @@ class CallablesInferenceContext(ImmutableRecord): # }}} - # {{{ must allocate a new clbl in the namespace => find a unique id for it - - unique_function_id = old_function_id - - while unique_function_id in self.callables: - unique_function_id = next_indexed_function_id(unique_function_id) - - # }}} + # must allocate a new clbl in the namespace => find a unique id for it + unique_function_id = self.clbl_name_gen(old_function_id) updated_callables = self.callables.copy() updated_callables[unique_function_id] = new_clbl @@ -668,12 +647,7 @@ class CallablesInferenceContext(ImmutableRecord): new_callables = dict(program.callables_table) for c in callees_with_old_entrypoint_names: - unique_func_id = c - - while unique_func_id in self.callables: - unique_func_id = next_indexed_function_id(unique_func_id) - - todo_renames[c] = unique_func_id + todo_renames[c] = 
self.clbl_name_gen(c) for e in self.new_entrypoints: # note renames to "rollback" the renaming of entrypoints @@ -769,6 +743,12 @@ def for_each_kernel(transform): def update_table(callables_table, clbl_id, clbl): + """ + Returns a tuple ``new_clbl_id, new_callables_table`` where + *new_callables_table* is a copy of *callables_table* with *clbl* in its + namespace. *clbl* is referred to in *new_callables_table*'s namespace by + *new_clbl_id*. + """ from loopy.kernel.function_interface import InKernelCallable assert isinstance(clbl, InKernelCallable) @@ -776,12 +756,19 @@ def update_table(callables_table, clbl_id, clbl): if c == clbl: return i, callables_table - while clbl_id in callables_table: - clbl_id = next_indexed_function_id(clbl_id) + if isinstance(clbl_id, ReductionOpFunction): + new_clbl_id = clbl_id.copy() + else: + assert isinstance(clbl_id, str) + from pytools import UniqueNameGenerator + all_substs = get_all_subst_names(callables_table) + ung = UniqueNameGenerator(set(callables_table.keys()) | all_substs) + new_clbl_id = ung(clbl_id) - callables_table[clbl_id] = clbl + new_callables_table = callables_table.copy() + new_callables_table[new_clbl_id] = clbl - return clbl_id, callables_table + return new_clbl_id, new_callables_table # }}}