From b7640ecf75669ef62f0a89a623bb60e9533e626a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 22 May 2021 18:44:24 -0500 Subject: [PATCH] translation unit minor fixes: - fix TranslationUnit.state when it does not contain any callable kernels - fix callables being "lost" during inference when they weren't reachable from the entrypoints --- loopy/translation_unit.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 269fd53f9..ac3b7076b 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -254,9 +254,11 @@ class TranslationUnit(ImmutableRecord): @property def state(self): """ Returns an instance of :class:`loopy.kernel.KernelState`. """ - return min(callable_knl.subkernel.state - for callable_knl in self.callables_table.values() - if isinstance(callable_knl, CallableKernel)) + from loopy.kernel import KernelState + return min((callable_knl.subkernel.state + for callable_knl in self.callables_table.values() + if isinstance(callable_knl, CallableKernel)), + default=KernelState.INITIAL) def with_kernel(self, kernel): """ @@ -662,7 +664,7 @@ class CallablesInferenceContext(ImmutableRecord): - self.new_entrypoints) todo_renames = {} - new_callables = {} + new_callables = dict(program.callables_table) for c in callees_with_old_entrypoint_names: unique_func_id = c @@ -808,7 +810,8 @@ def resolve_callables(program): callables_table = {} # callables: name of the calls seen in the program - callables = set(program.entrypoints) + callables = {name for name, clbl in program.callables_table.items() + if isinstance(clbl, CallableKernel)} while callables: clbl_name = callables.pop() -- GitLab