From e7c0908811ebe708d1eabf826b6d6845d23b1736 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Thu, 29 Apr 2021 09:24:14 -0500 Subject: [PATCH] define :meth:`InKernelCallable.get_called_callables` --- loopy/codegen/__init__.py | 6 +-- loopy/kernel/function_interface.py | 22 +++++++++ loopy/kernel/tools.py | 58 +++++++++++++++++++++++ loopy/preprocess.py | 7 +-- loopy/translation_unit.py | 76 ++++-------------------------- 5 files changed, 95 insertions(+), 74 deletions(-) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 14f0a75eb..86e18de34 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -605,11 +605,11 @@ def diverge_callee_entrypoints(program): If a :class:`loopy.kernel.function_interface.CallableKernel` is both an entrypoint and a callee, then rename the callee. """ - from loopy.translation_unit import (_get_reachable_callable_ids, + from loopy.translation_unit import (get_reachable_resolved_callable_ids, rename_resolved_functions_in_a_single_kernel) from pytools import UniqueNameGenerator - callable_ids = _get_reachable_callable_ids(program.callables_table, - program.entrypoints) + callable_ids = get_reachable_resolved_callable_ids(program.callables_table, + program.entrypoints) new_callables = {} todo_renames = {} diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 8c9a0f2ac..e4a91f1e7 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -313,6 +313,7 @@ class InKernelCallable(ImmutableRecord): .. automethod:: is_ready_for_codegen .. automethod:: get_hw_axes_sizes .. automethod:: get_used_hw_axes + .. automethod:: get_called_callables .. note:: @@ -481,6 +482,16 @@ class InKernelCallable(ImmutableRecord): """ raise NotImplementedError() + def get_called_callables(self, callables_table): + """ + Returns a :class:`frozenset` of callable ids called by *self* that are + resolved via *callables_table*. + + :arg callables_table: Similar to + :attr:`loopy.TranslationUnit.callables_table`. + """ + raise NotImplementedError + # }}} @@ -638,6 +649,12 @@ class ScalarCallable(InKernelCallable): def with_added_arg(self, arg_dtype, arg_descr): raise LoopyError("Cannot add args to scalar callables.") + def get_called_callables(self, callables_table): + """ + Returns a :class:`frozenset` of callable ids called by *self*. + """ + return frozenset() + # }}} @@ -927,6 +944,11 @@ class CallableKernel(InKernelCallable): return var(self.subkernel.name)(*tgt_parameters), False + def get_called_callables(self, callables_table): + from loopy.kernel.tools import get_resolved_callable_ids_called_by_knl + return get_resolved_callable_ids_called_by_knl(self.subkernel, + callables_table) + # }}} diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 19cb8acbd..8c12f1e35 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -36,6 +36,10 @@ from loopy.kernel import LoopKernel from loopy.translation_unit import (TranslationUnit, for_each_kernel) from loopy.kernel.function_interface import CallableKernel +from loopy.kernel.instruction import ( + MultiAssignmentBase, CInstruction, _DataObliviousInstruction) +from loopy.symbolic import CombineMapper +from functools import reduce import logging logger = logging.getLogger(__name__) @@ -1995,4 +1999,58 @@ def infer_args_are_input_output(kernel): # }}} + +# {{{ CallablesIDCollector + +class CallablesIDCollector(CombineMapper): + """ + Mapper to collect function identifiers of all resolved callables in an + expression. + """ + def combine(self, values): + import operator + return reduce(operator.or_, values, frozenset()) + + def map_resolved_function(self, expr): + return frozenset([expr.name]) + + def map_constant(self, expr): + return frozenset() + + def map_kernel(self, kernel): + callables_in_insn = frozenset() + + for insn in kernel.instructions: + if isinstance(insn, MultiAssignmentBase): + callables_in_insn = callables_in_insn | ( + self(insn.expression)) + elif isinstance(insn, (CInstruction, _DataObliviousInstruction)): + pass + else: + raise NotImplementedError(type(insn).__name__) + + for rule in kernel.substitutions.values(): + callables_in_insn = callables_in_insn | ( + self(rule.expression)) + + return callables_in_insn + + def map_type_cast(self, expr): + return self.rec(expr.child) + + map_variable = map_constant + map_function_symbol = map_constant + map_tagged_variable = map_constant + + +def get_resolved_callable_ids_called_by_knl(knl, callables): + clbl_id_collector = CallablesIDCollector() + callables_called_by_kernel = clbl_id_collector.map_kernel(knl) + callables_called_by_called_callables = frozenset().union(*( + callables[clbl_id].get_called_callables(callables) + for clbl_id in callables_called_by_kernel)) + return callables_called_by_kernel | callables_called_by_called_callables + +# }}} + # vim: foldmethod=marker diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 59e70827e..c28f14e80 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2363,9 +2363,10 @@ def inline_kernels_with_gbarriers(program): def filter_reachable_callables(t_unit): - from loopy.translation_unit import _get_reachable_callable_ids - reachable_function_ids = _get_reachable_callable_ids(t_unit.callables_table, - t_unit.entrypoints) + from loopy.translation_unit import get_reachable_resolved_callable_ids + reachable_function_ids = get_reachable_resolved_callable_ids(t_unit + .callables_table, + t_unit.entrypoints) new_callables = {name: clbl for name, clbl in t_unit.callables_table.items() if name in (reachable_function_ids | t_unit.entrypoints)} return t_unit.copy(callables_table=new_callables) diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 127e6341a..27c6392a5 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -27,18 +27,15 @@ from pymbolic.primitives import Variable from functools import wraps from loopy.symbolic import (RuleAwareIdentityMapper, ResolvedFunction, - CombineMapper, SubstitutionRuleMappingContext) + SubstitutionRuleMappingContext) from loopy.kernel.function_interface import ( CallableKernel, ScalarCallable) -from loopy.kernel.instruction import ( - MultiAssignmentBase, CInstruction, _DataObliviousInstruction) from loopy.diagnostic import LoopyError from loopy.library.reduction import ReductionOpFunction from loopy.kernel import LoopKernel from loopy.tools import update_persistent_hash from pymbolic.primitives import Call -from functools import reduce from pyrsistent import pmap, PMap __doc__ = """ @@ -411,70 +408,13 @@ def rename_resolved_functions_in_a_single_kernel(kernel, # }}} -# {{{ CallablesIDCollector - -class CallablesIDCollector(CombineMapper): +def get_reachable_resolved_callable_ids(callables, entrypoints): """ - Mapper to collect function identifiers of all resolved callables in an - expression. + Returns a :class:`frozenset` of callables ids that are resolved and + reachable from *entrypoints*. """ - def combine(self, values): - import operator - return reduce(operator.or_, values, frozenset()) - - def map_resolved_function(self, expr): - return frozenset([expr.name]) - - def map_constant(self, expr): - return frozenset() - - def map_kernel(self, kernel): - callables_in_insn = frozenset() - - for insn in kernel.instructions: - if isinstance(insn, MultiAssignmentBase): - callables_in_insn = callables_in_insn | ( - self(insn.expression)) - elif isinstance(insn, (CInstruction, _DataObliviousInstruction)): - pass - else: - raise NotImplementedError(type(insn).__name__) - - for rule in kernel.substitutions.values(): - callables_in_insn = callables_in_insn | ( - self(rule.expression)) - - return callables_in_insn - - def map_type_cast(self, expr): - return self.rec(expr.child) - - map_variable = map_constant - map_function_symbol = map_constant - map_tagged_variable = map_constant - - -def _get_reachable_callable_ids_for_knl(knl, callables): - clbl_id_collector = CallablesIDCollector() - - def rec(clbl_id): - clbl = callables[clbl_id] - if isinstance(clbl, CallableKernel): - return (_get_reachable_callable_ids_for_knl(clbl.subkernel, callables) - | frozenset([clbl_id])) - else: - return frozenset([clbl_id]) - - return frozenset().union(*(rec(clbl_id) - for clbl_id in clbl_id_collector.map_kernel(knl))) - - -def _get_reachable_callable_ids(callables, entrypoints): - return frozenset().union(*( - _get_reachable_callable_ids_for_knl(callables[e].subkernel, callables) - for e in entrypoints)) - -# }}} + return frozenset().union(*(callables[e].get_called_callables(callables) + for e in entrypoints)) # {{{ CallablesInferenceContext @@ -631,8 +571,8 @@ class CallablesInferenceContext(ImmutableRecord): # {{{ get all the callables reachable from the new entrypoints. # get the names of all callables reachable from the new entrypoints - new_callable_ids = _get_reachable_callable_ids( - self.callables, self.new_entrypoints) + new_callable_ids = get_reachable_resolved_callable_ids(self.callables, + self.new_entrypoints) # get the history of function ids from the performed renames: history = {} -- GitLab