Skip to content
Snippets Groups Projects
Commit 256a357a authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

Add pt.analysis.get_num_call_sites

parent edac8c3b
No related branches found
No related tags found
No related merge requests found
...@@ -31,9 +31,11 @@ from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, ...@@ -31,9 +31,11 @@ from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
DictOfNamedArrays, NamedArray, DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase, IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType) ShapeType)
from pytato.function import FunctionDefinition, Call
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.loopy import LoopyCall from pytato.loopy import LoopyCall
from pymbolic.mapper.optimize import optimize_mapper from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method
if TYPE_CHECKING: if TYPE_CHECKING:
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
...@@ -47,6 +49,8 @@ __doc__ = """ ...@@ -47,6 +49,8 @@ __doc__ = """
.. autofunction:: get_num_nodes .. autofunction:: get_num_nodes
.. autofunction:: get_num_call_sites
.. autoclass:: DirectPredecessorsGetter .. autoclass:: DirectPredecessorsGetter
""" """
...@@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: ...@@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
return ncm.count return ncm.count
# }}} # }}}
# {{{ CallSiteCountMapper
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class CallSiteCountMapper(CachedWalkMapper):
"""
Counts the number of :class:`~pytato.Call` nodes in a DAG.
.. attribute:: count
The number of nodes.
"""
def __init__(self) -> None:
super().__init__()
self.count = 0
# type-ignore-reason: dropped the extra `*args, **kwargs`.
def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
return id(expr)
@memoize_method
def map_function_definition(self, /, expr: FunctionDefinition,
*args: Any, **kwargs: Any) -> None:
if not self.visit(expr):
return
new_mapper = self.clone_for_callee()
for subexpr in expr.returns.values():
new_mapper(subexpr, *args, **kwargs)
self.count += new_mapper.count
self.post_visit(expr, *args, **kwargs)
# type-ignore-reason: dropped the extra `*args, **kwargs`.
def post_visit(self, expr: Any) -> None: # type: ignore[override]
if isinstance(expr, Call):
self.count += 1
def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:
"""Returns the number of nodes in DAG *outputs*."""
from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)
cscm = CallSiteCountMapper()
cscm(outputs)
return cscm.count
# }}}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment