diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 2c663aa35a752b3fe3b4e54b7b16d2fa0b167b38..f690d99566e7380cdd83995b62da6866e1abb5ea 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -59,7 +59,6 @@ def _is_a_reduction_op(expr): if isinstance(expr, ResolvedFunction): return _is_a_reduction_op(expr.function) - from loopy.library.reduction import ReductionOpFunction return isinstance(expr, ReductionOpFunction) @@ -458,11 +457,16 @@ class CallablesIDCollector(CombineMapper): def _get_reachable_callable_ids_for_knl(knl, callables): clbl_id_collector = CallablesIDCollector() - return frozenset().union(*( - _get_reachable_callable_ids_for_knl(callables[clbl].subkernel, callables) | - frozenset([clbl]) if isinstance(callables[clbl], CallableKernel) else - frozenset([clbl]) - for clbl in clbl_id_collector.map_kernel(knl))) + def rec(clbl_id): + clbl = callables[clbl_id] + if isinstance(clbl, CallableKernel): + return (_get_reachable_callable_ids_for_knl(clbl.subkernel, callables) + | frozenset([clbl_id])) + else: + return frozenset([clbl_id]) + + return frozenset().union(*(rec(clbl_id) + for clbl_id in clbl_id_collector.map_kernel(knl))) def _get_reachable_callable_ids(callables, entrypoints): @@ -475,6 +479,7 @@ def _get_reachable_callable_ids(callables, entrypoints): # {{{ CallablesInferenceContext + def get_all_subst_names(callables): """ Returns a :class:`set` of all substitution rule names in the callable