diff --git a/doc/ref_translation_unit.rst b/doc/ref_translation_unit.rst index 9d7c49158e7953f0be23fc6081b33aca47d14b1f..631c5756130026cf10f4708f28ff76e1a5539722 100644 --- a/doc/ref_translation_unit.rst +++ b/doc/ref_translation_unit.rst @@ -4,3 +4,8 @@ TranslationUnit =============== .. autoclass:: TranslationUnit + +Reference +--------- + +.. automodule:: loopy.translation_unit diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 7efeffa78a19e1720ff3646d8f840b9906521c20..53ddcefe1f6ef93bb6b38f6f4ad65c83e56e688d 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2151,12 +2151,10 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): Infers the :attr:`loopy` """ - def __init__(self, rule_mapping_context, caller_kernel, - callables_table): - super().__init__( - rule_mapping_context) + def __init__(self, rule_mapping_context, caller_kernel, clbl_inf_ctx): + super().__init__(rule_mapping_context) self.caller_kernel = caller_kernel - self.callables_table = callables_table + self.clbl_inf_ctx = clbl_inf_ctx def map_call(self, expr, expn_state, assignees=None): from pymbolic.primitives import Call, CallWithKwargs, Variable @@ -2185,7 +2183,7 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): arg_id: get_arg_descriptor_for_expression( self.caller_kernel, arg) for arg_id, arg in arg_id_to_val.items()} - in_knl_callable = self.callables_table[expr.function.name] + in_knl_callable = self.clbl_inf_ctx[expr.function.name] # {{{ translating descriptor expressions to the callable's namespace @@ -2219,14 +2217,14 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): # }}} # specializing the function according to the parameter description - new_in_knl_callable, self.callables_table = ( + new_in_knl_callable, self.clbl_inf_ctx = ( in_knl_callable.with_descrs( - arg_id_to_descr, self.callables_table)) + arg_id_to_descr, self.clbl_inf_ctx)) # find the deps of the new in kernel callablen and add those arguments to - self.callables_table, new_func_id = ( - self.callables_table.with_callable( + self.clbl_inf_ctx, new_func_id = ( + self.clbl_inf_ctx.with_callable( expr.function.function, new_in_knl_callable)) @@ -2306,7 +2304,7 @@ def traverse_to_infer_arg_descr(kernel, callables_table): descr_inferred_kernel = rule_mapping_context.finish_kernel( arg_descr_inf_mapper.map_kernel(kernel)) - return descr_inferred_kernel, arg_descr_inf_mapper.callables_table + return descr_inferred_kernel, arg_descr_inf_mapper.clbl_inf_ctx def infer_arg_descr(program): @@ -2324,9 +2322,7 @@ def infer_arg_descr(program): program = resolve_callables(program) clbl_inf_ctx = make_clbl_inf_ctx(program.callables_table, - program.entrypoints) - - renamed_entrypoints = set() + program.entrypoints) for e in program.entrypoints: def _tuple_or_None(s): @@ -2350,10 +2346,10 @@ def infer_arg_descr(program): raise NotImplementedError() new_callable, clbl_inf_ctx = program.callables_table[e].with_descrs( arg_id_to_descr, clbl_inf_ctx) - clbl_inf_ctx, new_name = clbl_inf_ctx.with_callable(e, new_callable) - renamed_entrypoints.add(new_name.name) + clbl_inf_ctx, new_name = clbl_inf_ctx.with_callable(e, new_callable, + is_entrypoint=True) - return clbl_inf_ctx.finish_program(program, renamed_entrypoints) + return clbl_inf_ctx.finish_program(program) # }}} diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 14ed2d4007979dfe88b40b1c7ff7526e03ba37a8..9ce69b4ae81a5972cd2505a9ef56fadfeb056dce 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -44,10 +44,12 @@ from pyrsistent import pmap, PMap __doc__ = """ -.. currentmodule:: loopy +.. currentmodule:: loopy.translation_unit .. autoclass:: TranslationUnit +.. autoclass:: CallablesInferenceContext + .. autofunction:: make_program .. autofunction:: iterate_over_kernels_if_given_program @@ -396,10 +398,10 @@ class Program(TranslationUnit): # }}} -def next_indexed_function_identifier(function_id): +def next_indexed_function_id(function_id): """ Returns an instance of :class:`str` with the next indexed-name in the - sequence for the name of *function*. + sequence for the name of *function_id*. *Example:* ``'sin_0'`` will return ``'sin_1'``. @@ -462,9 +464,8 @@ def rename_resolved_functions_in_a_single_kernel(kernel, class CallablesIDCollector(CombineMapper): """ - Returns an instance of :class:`frozenset` containing instances of - :class:`loopy.kernel.function_interface.InKernelCallable` in the - :attr:``kernel`. + Mapper to collect function identifiers of all resolved callables in an + expression. """ def combine(self, values): import operator @@ -512,192 +513,207 @@ def _get_callable_ids_for_knl(knl, callables): def _get_callable_ids(callables, entrypoints): return frozenset().union(*( - _get_callable_ids_for_knl(callables[e].subkernel, callables) for e in - entrypoints)) + _get_callable_ids_for_knl(callables[e].subkernel, callables) + for e in entrypoints)) def make_clbl_inf_ctx(callables, entrypoints): - return CallablesInferenceContext(callables, _get_callable_ids(callables, - entrypoints)) + return CallablesInferenceContext(callables) class CallablesInferenceContext(ImmutableRecord): - def __init__(self, callables, old_callable_ids, history={}): + """ + Helper class for housekeeping a :attr:`loopy.TranslationUnit.callables_table` + while traversing through callables of :class:`loopy.TranslationUnit`. + + .. attribute:: callables + + A mapping from the callable names to instances of + :class:`loopy.kernel.function_interface.InKernelCallable`. + + .. attribute:: renames + + A mapping from old function identifiers to a :class:`frozenset` of new + function identifiers. + + .. attribute:: new_entrypoints + + A :class:`frozenset` of renamed entrypoint names. + + .. automethod:: with_callable + + .. automethod:: finish_program + + .. automethod:: __getitem__ + """ + def __init__(self, callables, + renames=collections.defaultdict(frozenset), + new_entrypoints=frozenset()): assert isinstance(callables, collections.abc.Mapping) callables = dict(callables) - super().__init__( - callables=callables, - old_callable_ids=old_callable_ids, - history=history) + super().__init__(callables=callables, + renames=renames, + new_entrypoints=new_entrypoints) # {{{ interface to perform edits on callables - def with_callable(self, function, in_kernel_callable): + def with_callable(self, old_function_id, new_clbl, + is_entrypoint=False): """ - Returns an instance of :class:`tuple` ``(new_self, new_function)``. + Updates the callable referred by *function_id*'s in *self*'s namespace + to *new_clbl*. - :arg function: An instance of :class:`pymbolic.primitives.Variable` or + :arg old_function_id: An instance of :class:`pymbolic.primitives.Variable` or :class:`loopy.library.reduction.ReductionOpFunction`. - :arg in_kernel_callable: An instance of - :class:`loopy.InKernelCallable`. + :arg new_clbl: An instance of + :class:`loopy.kernel.function_interface.InKernelCallable`. + + :returns: ``(new_self, new_function_id)`` is a copy of *self* with + *new_clbl* in its namespace. *new_clbl* would be referred by + *new_function_id* in *new_self*'s namespace. """ - # {{{ sanity checks + assert isinstance(old_function_id, (str, Variable, ReductionOpFunction)) + + if isinstance(old_function_id, Variable): + old_function_id = old_function_id.name + + renames = self.renames.copy() + + # if the callable already exists => return the function + # identifier corresponding to that callable. + for func_id, clbl in self.callables.items(): + if clbl == new_clbl: + renames[old_function_id] |= frozenset([func_id]) + if isinstance(func_id, str): + new_entrypoints = self.new_entrypoints + if is_entrypoint: + new_entrypoints |= frozenset([func_id]) + return (self.copy(renames=renames, + new_entrypoints=new_entrypoints), + Variable(func_id),) + else: + assert not is_entrypoint + assert isinstance(func_id, ReductionOpFunction) + return (self.copy(renames=renames), + func_id) + + # {{{ handle ReductionOpFunction + + if isinstance(old_function_id, ReductionOpFunction): + # FIXME: Check if we have 2 ArgMax functions + # with different types in the same kernel the generated code + # does not mess up the types. + assert not is_entrypoint + unique_function_id = old_function_id.copy() + updated_callables = self.callables.copy() + updated_callables[unique_function_id] = new_clbl + renames[old_function_id] |= frozenset([unique_function_id]) + + return (self.copy(callables=updated_callables, + renames=renames), + unique_function_id) - if isinstance(function, str): - function = Variable(function) + # }}} - assert isinstance(function, (Variable, ReductionOpFunction)) + # {{{ must allocate a new clbl in the namespace => find a unique id for it - # }}} + unique_function_id = old_function_id - history = self.history.copy() - - if in_kernel_callable in self.callables.values(): - # the callable already exists, hence return the function - # identifier corresponding to that callable. - for func_id, in_knl_callable in self.callables.items(): - if in_knl_callable == in_kernel_callable: - history[func_id] = function.name - if isinstance(func_id, str): - return ( - self.copy( - history=history), - Variable(func_id)) - else: - assert isinstance(func_id, ReductionOpFunction) - return ( - self.copy( - history=history), - func_id) - - assert False - else: - # {{{ handle ReductionOpFunction + while unique_function_id in self.callables: + unique_function_id = next_indexed_function_id(unique_function_id) - if isinstance(function, ReductionOpFunction): - # FIXME: Check if we have 2 ArgMax functions - # with different types in the same kernel the generated code - # does not mess up the types. - unique_function_identifier = function.copy() - updated_callables = self.callables.copy() - updated_callables[unique_function_identifier] = ( - in_kernel_callable) + # }}} - return ( - self.copy( - callables=updated_callables), - unique_function_identifier) + updated_callables = self.callables.copy() + updated_callables[unique_function_id] = new_clbl + renames[old_function_id] |= frozenset([unique_function_id]) - # }}} + new_entrypoints = self.new_entrypoints + if is_entrypoint: + new_entrypoints |= frozenset([unique_function_id]) - unique_function_identifier = function.name + return (self.copy(renames=renames, + callables=updated_callables, + new_entrypoints=new_entrypoints), + Variable(unique_function_id)) - while unique_function_identifier in self.callables: - unique_function_identifier = ( - next_indexed_function_identifier( - unique_function_identifier)) + def finish_program(self, program): + """ + Returns a copy of *program* with rollback renaming of the callables + done whenever possible. - updated_callables = self.callables.copy() - updated_callables[unique_function_identifier] = ( - in_kernel_callable) + For example: If all the ``sin`` function ids got diverged as + ``sin_0``, ``sin_1``, then all the renaming is done such that one of + flavors of the callable is renamed back to ``sin``. + """ + # FIXME: Generalize this if an inference happens over a proper subgraph + # of the callgraph (the following assert should be removed) + assert len(self.new_entrypoints) == len(program.entrypoints) - history[unique_function_identifier] = function.name + # {{{ get all the callables reachable from the new entrypoints. - return ( - self.copy( - history=history, - callables=updated_callables), - Variable(unique_function_identifier)) + # get the names of all callables reachable from the new entrypoints + new_callable_ids = _get_callable_ids(self.callables, self.new_entrypoints) - def finish_program(self, program, renamed_entrypoints): - """ - Returns a copy of *program* with renaming of the callables done whenever - needed. + # get the history of function ids from the performed renames: + history = {} + for old_func_id, new_func_ids in self.renames.items(): + for new_func_id in new_func_ids: + if new_func_id in (new_callable_ids | self.new_entrypoints): + history[new_func_id] = old_func_id - *For example: * If all the ``sin`` got diverged as ``sin_0, sin_1``, - then all the renaming is done such that one of flavors of the callable - is renamed back to ``sin``. + # }}} - :param renamed_entrypoints: A :class:`frozenset` of the names of the - renamed callable kernels which correspond to the entrypoints in - *self.callables_table*. - """ - assert len(renamed_entrypoints) == len(program.entrypoints) - new_callable_ids = _get_callable_ids(self.callables, renamed_entrypoints) + # AIM: Preserve the entrypoints of *program* - callees_with_entrypoint_names = (program.entrypoints & - new_callable_ids) - renamed_entrypoints + # If there are any callees having old entrypoint names => mark them for + # renaming + callees_with_old_entrypoint_names = ((program.entrypoints & new_callable_ids) + - self.new_entrypoints) - renames = {} + todo_renames = {} new_callables = {} - for c in callees_with_entrypoint_names: - unique_function_identifier = c + for c in callees_with_old_entrypoint_names: + unique_func_id = c - while unique_function_identifier in self.callables: - unique_function_identifier = ( - next_indexed_function_identifier( - unique_function_identifier)) + while unique_func_id in self.callables: + unique_func_id = next_indexed_function_id(unique_func_id) - renames[c] = unique_function_identifier + todo_renames[c] = unique_func_id - # we should perform a rewrite here. + for e in self.new_entrypoints: + # note renames to "rollback" the renaming of entrypoints + todo_renames[e] = history[e] + assert todo_renames[e] in program.entrypoints - for e in renamed_entrypoints: - renames[e] = self.history[e] - assert renames[e] in program.entrypoints + # try to rollback the names as much as possible + for new_id in new_callable_ids: + old_func_id = history[new_id] + if (isinstance(old_func_id, str) + and old_func_id not in set(todo_renames.values())): + todo_renames[new_id] = old_func_id - # {{{ calculate the renames needed + # {{{ perform the renames form todo_renames - for old_func_id in ((self.old_callable_ids-new_callable_ids) - - program.entrypoints): - # at this point we should not rename anything to the names of - # entrypoints - for new_func_id in (new_callable_ids-renames.keys()) & set( - self.history.keys()): - if old_func_id == self.history[new_func_id]: - renames[new_func_id] = old_func_id - break - # }}} + for func_id in (new_callable_ids | self.new_entrypoints): + clbl = self.callables[func_id] + if func_id in todo_renames: + assert history[func_id] == todo_renames[func_id] + func_id = todo_renames[func_id] + if isinstance(clbl, CallableKernel): + subknl = clbl.subkernel.copy(name=func_id) + subknl = rename_resolved_functions_in_a_single_kernel(subknl, + todo_renames) - for e in renamed_entrypoints: - new_subkernel = self.callables[e].subkernel.copy(name=self.history[e]) - new_subkernel = rename_resolved_functions_in_a_single_kernel( - new_subkernel, renames) - new_callables[self.history[e]] = self.callables[e].copy( - subkernel=new_subkernel) - - for func_id in new_callable_ids-renamed_entrypoints: - in_knl_callable = self.callables[func_id] - if isinstance(in_knl_callable, CallableKernel): - # if callable kernel, perform renames inside its expressions. - old_subkernel = in_knl_callable.subkernel - new_subkernel = rename_resolved_functions_in_a_single_kernel( - old_subkernel, renames) - in_knl_callable = ( - in_knl_callable.copy(subkernel=new_subkernel)) - elif isinstance(in_knl_callable, ScalarCallable): - pass - else: - raise NotImplementedError("Unknown callable type %s." % - type(in_knl_callable).__name__) + clbl = clbl.copy(subkernel=subknl) - if func_id in renames: - new_func_id = renames[func_id] - if isinstance(in_knl_callable, CallableKernel): - in_knl_callable = (in_knl_callable.copy( - subkernel=in_knl_callable.subkernel.copy( - name=new_func_id))) - new_callables[new_func_id] = in_knl_callable - else: - if isinstance(in_knl_callable, CallableKernel): - in_knl_callable = in_knl_callable.copy( - subkernel=in_knl_callable.subkernel.copy( - name=func_id)) - new_callables[func_id] = in_knl_callable + new_callables[func_id] = clbl + + # }}} return program.copy(callables_table=new_callables) @@ -775,7 +791,7 @@ def update_table(callables_table, clbl_id, clbl): return i, callables_table while clbl_id in callables_table: - clbl_id = next_indexed_function_identifier(clbl_id) + clbl_id = next_indexed_function_id(clbl_id) callables_table[clbl_id] = clbl diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 24df0ea15f2ba6fd3c388a69cde43710f54a0e54..36d5408a0c5faa24f5182919fdd563ab6cf80477 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -1018,17 +1018,14 @@ def infer_unknown_types(program, expect_completion=False): clbl_inf_ctx = make_clbl_inf_ctx(program.callables_table, program.entrypoints) - renamed_entrypoints = set() - for e in program.entrypoints: logger.debug(f"Entering entrypoint: {e}") arg_id_to_dtype = {arg.name: arg.dtype for arg in program[e].args if arg.dtype not in (None, auto)} new_callable, clbl_inf_ctx = program.callables_table[e].with_types( arg_id_to_dtype, clbl_inf_ctx) - clbl_inf_ctx, new_name = clbl_inf_ctx.with_callable(e, new_callable) - renamed_entrypoints.add(new_name.name) - + clbl_inf_ctx, new_name = clbl_inf_ctx.with_callable(e, new_callable, + is_entrypoint=True) if expect_completion: from loopy.types import LoopyType new_knl = new_callable.subkernel @@ -1048,7 +1045,7 @@ def infer_unknown_types(program, expect_completion=False): raise LoopyError("could not determine type of" f" '{vars_not_inferred.pop()}' of kernel '{e}'.") - return clbl_inf_ctx.finish_program(program, renamed_entrypoints) + return clbl_inf_ctx.finish_program(program) # }}}