diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 6fba7c56090707781a8732127efe8705d0d7b82c..a0122576a9910292481ec0ce3a8b09da34a9198b 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -2010,4 +2010,43 @@ def get_resolved_callable_ids_called_by_knl(knl, callables, recursive=True): # }}} + +# {{{ get_call_graph + +def get_call_graph(t_unit, only_kernel_callables=False): + """ + Returns a mapping from a callable name to the calls seen in it. + + :arg t_unit: An instance of :class:`TranslationUnit`. + """ + from pyrsistent import pmap + from loopy.kernel import KernelState + + if t_unit.state < KernelState.CALLS_RESOLVED: + raise LoopyError("TranslationUnit must have calls resolved in order to" + " compute its call graph.") + + knl_callables = frozenset(name for name, clbl in t_unit.callables_table.items() + if isinstance(clbl, CallableKernel)) + + # stores a mapping from caller -> "direct"" callees + call_graph = {} + + for name, clbl in t_unit.callables_table.items(): + if (not isinstance(clbl, CallableKernel) + and only_kernel_callables): + pass + else: + if only_kernel_callables: + call_graph[name] = (clbl.get_called_callables(t_unit.callables_table, + recursive=False) + & knl_callables) + else: + call_graph[name] = clbl.get_called_callables(t_unit.callables_table, + recursive=False) + + return pmap(call_graph) + +# }}} + # vim: foldmethod=marker