diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 93dd745ee6aef687dc46004b434f414e6dfc0fe5..14f0a75ebc6f7a1b4880b1b68bc56cf23d3fad7b 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -605,10 +605,10 @@ def diverge_callee_entrypoints(program): If a :class:`loopy.kernel.function_interface.CallableKernel` is both an entrypoint and a callee, then rename the callee. """ - from loopy.translation_unit import (_get_callable_ids, + from loopy.translation_unit import (_get_reachable_callable_ids, rename_resolved_functions_in_a_single_kernel) from pytools import UniqueNameGenerator - callable_ids = _get_callable_ids(program.callables_table, + callable_ids = _get_reachable_callable_ids(program.callables_table, program.entrypoints) new_callables = {} diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 87b42160de68f898c49df957b7bb754726d40815..59e70827edaf4bd21cfe0f975f1b8e7b02797ed1 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2363,8 +2363,8 @@ def inline_kernels_with_gbarriers(program): def filter_reachable_callables(t_unit): - from loopy.translation_unit import _get_callable_ids - reachable_function_ids = _get_callable_ids(t_unit.callables_table, + from loopy.translation_unit import _get_reachable_callable_ids + reachable_function_ids = _get_reachable_callable_ids(t_unit.callables_table, t_unit.entrypoints) new_callables = {name: clbl for name, clbl in t_unit.callables_table.items() if name in (reachable_function_ids | t_unit.entrypoints)} diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index ac3b7076b852a8633276cd5468ecf6c8e4aefeb0..d76e6f9fea63c704bb06dcab5edc109145e27d15 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -488,19 +488,19 @@ class CallablesIDCollector(CombineMapper): map_type_cast = map_constant -def _get_callable_ids_for_knl(knl, callables): +def _get_reachable_callable_ids_for_knl(knl, callables): clbl_id_collector = CallablesIDCollector() return frozenset().union(*( - _get_callable_ids_for_knl(callables[clbl].subkernel, callables) | + _get_reachable_callable_ids_for_knl(callables[clbl].subkernel, callables) | frozenset([clbl]) if isinstance(callables[clbl], CallableKernel) else frozenset([clbl]) for clbl in clbl_id_collector.map_kernel(knl))) -def _get_callable_ids(callables, entrypoints): +def _get_reachable_callable_ids(callables, entrypoints): return frozenset().union(*( - _get_callable_ids_for_knl(callables[e].subkernel, callables) + _get_reachable_callable_ids_for_knl(callables[e].subkernel, callables) for e in entrypoints)) # }}} @@ -645,7 +645,8 @@ class CallablesInferenceContext(ImmutableRecord): # {{{ get all the callables reachable from the new entrypoints. # get the names of all callables reachable from the new entrypoints - new_callable_ids = _get_callable_ids(self.callables, self.new_entrypoints) + new_callable_ids = _get_reachable_callable_ids( + self.callables, self.new_entrypoints) # get the history of function ids from the performed renames: history = {}