From e7c0908811ebe708d1eabf826b6d6845d23b1736 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Thu, 29 Apr 2021 09:24:14 -0500
Subject: [PATCH] define :meth:`InKernelCallable.get_called_callables`

---
 loopy/codegen/__init__.py          |  6 +--
 loopy/kernel/function_interface.py | 22 +++++++++
 loopy/kernel/tools.py              | 58 +++++++++++++++++++++++
 loopy/preprocess.py                |  7 +--
 loopy/translation_unit.py          | 76 ++++--------------------------
 5 files changed, 95 insertions(+), 74 deletions(-)

diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 14f0a75eb..86e18de34 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -605,11 +605,11 @@ def diverge_callee_entrypoints(program):
     If a :class:`loopy.kernel.function_interface.CallableKernel` is both an
     entrypoint and a callee, then rename the callee.
     """
-    from loopy.translation_unit import (_get_reachable_callable_ids,
+    from loopy.translation_unit import (get_reachable_resolved_callable_ids,
                                         rename_resolved_functions_in_a_single_kernel)
     from pytools import UniqueNameGenerator
-    callable_ids = _get_reachable_callable_ids(program.callables_table,
-                                     program.entrypoints)
+    callable_ids = get_reachable_resolved_callable_ids(program.callables_table,
+                                                       program.entrypoints)
 
     new_callables = {}
     todo_renames = {}
diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py
index 8c9a0f2ac..e4a91f1e7 100644
--- a/loopy/kernel/function_interface.py
+++ b/loopy/kernel/function_interface.py
@@ -313,6 +313,7 @@ class InKernelCallable(ImmutableRecord):
     .. automethod:: is_ready_for_codegen
     .. automethod:: get_hw_axes_sizes
     .. automethod:: get_used_hw_axes
+    .. automethod:: get_called_callables
 
     .. note::
 
@@ -481,6 +482,16 @@ class InKernelCallable(ImmutableRecord):
         """
         raise NotImplementedError()
 
+    def get_called_callables(self, callables_table):
+        """
+        Returns a :class:`frozenset` of callable ids called by *self* that are
+        resolved via *callables_table*.
+
+        :arg callables_table: Similar to
+            :attr:`loopy.TranslationUnit.callables_table`.
+        """
+        raise NotImplementedError
+
 # }}}
 
 
@@ -638,6 +649,12 @@ class ScalarCallable(InKernelCallable):
     def with_added_arg(self, arg_dtype, arg_descr):
         raise LoopyError("Cannot add args to scalar callables.")
 
+    def get_called_callables(self, callables_table):
+        """
+        Returns a :class:`frozenset` of callable ids called by *self*.
+        """
+        return frozenset()
+
 # }}}
 
 
@@ -927,6 +944,11 @@ class CallableKernel(InKernelCallable):
 
         return var(self.subkernel.name)(*tgt_parameters), False
 
+    def get_called_callables(self, callables_table):
+        from loopy.kernel.tools import get_resolved_callable_ids_called_by_knl
+        return get_resolved_callable_ids_called_by_knl(self.subkernel,
+                                                       callables_table)
+
 # }}}
 
 
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 19cb8acbd..8c12f1e35 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -36,6 +36,10 @@ from loopy.kernel import LoopKernel
 from loopy.translation_unit import (TranslationUnit,
                                     for_each_kernel)
 from loopy.kernel.function_interface import CallableKernel
+from loopy.kernel.instruction import (
+        MultiAssignmentBase, CInstruction, _DataObliviousInstruction)
+from loopy.symbolic import CombineMapper
+from functools import reduce
 import logging
 logger = logging.getLogger(__name__)
 
@@ -1995,4 +1999,58 @@ def infer_args_are_input_output(kernel):
 
 # }}}
 
+
+# {{{ CallablesIDCollector
+
+class CallablesIDCollector(CombineMapper):
+    """
+    Mapper to collect function identifiers of all resolved callables in an
+    expression.
+    """
+    def combine(self, values):
+        import operator
+        return reduce(operator.or_, values, frozenset())
+
+    def map_resolved_function(self, expr):
+        return frozenset([expr.name])
+
+    def map_constant(self, expr):
+        return frozenset()
+
+    def map_kernel(self, kernel):
+        callables_in_insn = frozenset()
+
+        for insn in kernel.instructions:
+            if isinstance(insn, MultiAssignmentBase):
+                callables_in_insn = callables_in_insn | (
+                        self(insn.expression))
+            elif isinstance(insn, (CInstruction, _DataObliviousInstruction)):
+                pass
+            else:
+                raise NotImplementedError(type(insn).__name__)
+
+        for rule in kernel.substitutions.values():
+            callables_in_insn = callables_in_insn | (
+                    self(rule.expression))
+
+        return callables_in_insn
+
+    def map_type_cast(self, expr):
+        return self.rec(expr.child)
+
+    map_variable = map_constant
+    map_function_symbol = map_constant
+    map_tagged_variable = map_constant
+
+
+def get_resolved_callable_ids_called_by_knl(knl, callables):
+    clbl_id_collector = CallablesIDCollector()
+    callables_called_by_kernel = clbl_id_collector.map_kernel(knl)
+    callables_called_by_called_callables = frozenset().union(*(
+        callables[clbl_id].get_called_callables(callables)
+        for clbl_id in callables_called_by_kernel))
+    return callables_called_by_kernel | callables_called_by_called_callables
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 59e70827e..c28f14e80 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -2363,9 +2363,10 @@ def inline_kernels_with_gbarriers(program):
 
 
 def filter_reachable_callables(t_unit):
-    from loopy.translation_unit import _get_reachable_callable_ids
-    reachable_function_ids = _get_reachable_callable_ids(t_unit.callables_table,
-                                               t_unit.entrypoints)
+    from loopy.translation_unit import get_reachable_resolved_callable_ids
+    reachable_function_ids = get_reachable_resolved_callable_ids(t_unit
+                                                                 .callables_table,
+                                                                 t_unit.entrypoints)
     new_callables = {name: clbl for name, clbl in t_unit.callables_table.items()
                      if name in (reachable_function_ids | t_unit.entrypoints)}
     return t_unit.copy(callables_table=new_callables)
diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py
index 127e6341a..27c6392a5 100644
--- a/loopy/translation_unit.py
+++ b/loopy/translation_unit.py
@@ -27,18 +27,15 @@ from pymbolic.primitives import Variable
 from functools import wraps
 
 from loopy.symbolic import (RuleAwareIdentityMapper, ResolvedFunction,
-        CombineMapper, SubstitutionRuleMappingContext)
+                            SubstitutionRuleMappingContext)
 from loopy.kernel.function_interface import (
         CallableKernel, ScalarCallable)
-from loopy.kernel.instruction import (
-        MultiAssignmentBase, CInstruction, _DataObliviousInstruction)
 from loopy.diagnostic import LoopyError
 from loopy.library.reduction import ReductionOpFunction
 
 from loopy.kernel import LoopKernel
 from loopy.tools import update_persistent_hash
 from pymbolic.primitives import Call
-from functools import reduce
 from pyrsistent import pmap, PMap
 
 __doc__ = """
@@ -411,70 +408,13 @@ def rename_resolved_functions_in_a_single_kernel(kernel,
 # }}}
 
 
-# {{{ CallablesIDCollector
-
-class CallablesIDCollector(CombineMapper):
+def get_reachable_resolved_callable_ids(callables, entrypoints):
     """
-    Mapper to collect function identifiers of all resolved callables in an
-    expression.
+    Returns a :class:`frozenset` of callables ids that are resolved and
+    reachable from *entrypoints*.
     """
-    def combine(self, values):
-        import operator
-        return reduce(operator.or_, values, frozenset())
-
-    def map_resolved_function(self, expr):
-        return frozenset([expr.name])
-
-    def map_constant(self, expr):
-        return frozenset()
-
-    def map_kernel(self, kernel):
-        callables_in_insn = frozenset()
-
-        for insn in kernel.instructions:
-            if isinstance(insn, MultiAssignmentBase):
-                callables_in_insn = callables_in_insn | (
-                        self(insn.expression))
-            elif isinstance(insn, (CInstruction, _DataObliviousInstruction)):
-                pass
-            else:
-                raise NotImplementedError(type(insn).__name__)
-
-        for rule in kernel.substitutions.values():
-            callables_in_insn = callables_in_insn | (
-                    self(rule.expression))
-
-        return callables_in_insn
-
-    def map_type_cast(self, expr):
-        return self.rec(expr.child)
-
-    map_variable = map_constant
-    map_function_symbol = map_constant
-    map_tagged_variable = map_constant
-
-
-def _get_reachable_callable_ids_for_knl(knl, callables):
-    clbl_id_collector = CallablesIDCollector()
-
-    def rec(clbl_id):
-        clbl = callables[clbl_id]
-        if isinstance(clbl, CallableKernel):
-            return (_get_reachable_callable_ids_for_knl(clbl.subkernel, callables)
-                    | frozenset([clbl_id]))
-        else:
-            return frozenset([clbl_id])
-
-    return frozenset().union(*(rec(clbl_id)
-                               for clbl_id in clbl_id_collector.map_kernel(knl)))
-
-
-def _get_reachable_callable_ids(callables, entrypoints):
-    return frozenset().union(*(
-        _get_reachable_callable_ids_for_knl(callables[e].subkernel, callables)
-        for e in entrypoints))
-
-# }}}
+    return frozenset().union(*(callables[e].get_called_callables(callables)
+                               for e in entrypoints))
 
 
 # {{{ CallablesInferenceContext
@@ -631,8 +571,8 @@ class CallablesInferenceContext(ImmutableRecord):
         # {{{ get all the callables reachable from the new entrypoints.
 
         # get the names of all callables reachable from the new entrypoints
-        new_callable_ids = _get_reachable_callable_ids(
-                self.callables, self.new_entrypoints)
+        new_callable_ids = get_reachable_resolved_callable_ids(self.callables,
+                                                               self.new_entrypoints)
 
         # get the history of function ids from the performed renames:
         history = {}
-- 
GitLab