diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a3c5e99ced0b03de587260e3856f0312db14d12b..f279a35244ec90eeb035198dea93231d82042de2 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -31,9 +31,11 @@ from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, ShapeType) +from pytato.function import FunctionDefinition, Call from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder @@ -47,6 +49,8 @@ __doc__ = """ .. autofunction:: get_num_nodes +.. autofunction:: get_num_call_sites + .. autoclass:: DirectPredecessorsGetter """ @@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: return ncm.count # }}} + + +# {{{ CallSiteCountMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class CallSiteCountMapper(CachedWalkMapper): + """ + Counts the number of :class:`~pytato.Call` nodes in a DAG. + + .. attribute:: count + + The number of nodes. + """ + + def __init__(self) -> None: + super().__init__() + self.count = 0 + + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + return id(expr) + + @memoize_method + def map_function_definition(self, /, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee() + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self.count += new_mapper.count + + self.post_visit(expr, *args, **kwargs) + + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def post_visit(self, expr: Any) -> None: # type: ignore[override] + if isinstance(expr, Call): + self.count += 1 + + +def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Returns the number of nodes in DAG *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + cscm = CallSiteCountMapper() + cscm(outputs) + + return cscm.count + +# }}}