diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3b21b6e1512c3064b9f9d429b5043fd9e260c897..7f6a93da5e66ebe6fba8a4244526eadf0cd2bc50 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -25,13 +25,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Mapping, Dict, Union, Set, Tuple, Any +from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, + TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, - IndexBase, IndexRemappingBase, InputArgumentBase) + IndexBase, IndexRemappingBase, InputArgumentBase, + ShapeType) from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper from pytato.loopy import LoopyCall +if TYPE_CHECKING: + from pytato.distributed import DistributedRecv, DistributedSendRefHolder + __doc__ = """ .. currentmodule:: pytato.analysis @@ -40,6 +45,8 @@ __doc__ = """ .. autofunction:: is_einsum_similar_to_subscript .. autofunction:: get_num_nodes + +.. autoclass:: DirectPredecessorsGetter """ @@ -267,6 +274,83 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: return True +# {{{ DirectPredecessorsGetter + +class DirectPredecessorsGetter(Mapper): + """ + Mapper to get the + `direct predecessors + `__ + of a node. + + .. note:: + + We only consider the predecessors of a nodes in a data-flow sense. + """ + def _get_preds_from_shape(self, shape: ShapeType) -> FrozenSet[Array]: + return frozenset({dim for dim in shape if isinstance(dim, Array)}) + + def map_index_lambda(self, expr: IndexLambda) -> FrozenSet[Array]: + return (frozenset(expr.bindings.values()) + | self._get_preds_from_shape(expr.shape)) + + def map_stack(self, expr: Stack) -> FrozenSet[Array]: + return (frozenset(expr.arrays) + | self._get_preds_from_shape(expr.shape)) + + def map_concatenate(self, expr: Concatenate) -> FrozenSet[Array]: + return (frozenset(expr.arrays) + | self._get_preds_from_shape(expr.shape)) + + def map_einsum(self, expr: Einsum) -> FrozenSet[Array]: + return (frozenset(expr.args) + | self._get_preds_from_shape(expr.shape)) + + def map_loopy_call_result(self, expr: NamedArray) -> FrozenSet[Array]: + from pytato.loopy import LoopyCallResult, LoopyCall + assert isinstance(expr, LoopyCallResult) + assert isinstance(expr._container, LoopyCall) + return (frozenset(ary + for ary in expr._container.bindings.values() + if isinstance(ary, Array)) + | self._get_preds_from_shape(expr.shape)) + + def _map_index_base(self, expr: IndexBase) -> FrozenSet[Array]: + return (frozenset([expr.array]) + | frozenset(idx for idx in expr.indices + if isinstance(idx, Array)) + | self._get_preds_from_shape(expr.shape)) + + map_basic_index = _map_index_base + map_contiguous_advanced_index = _map_index_base + map_non_contiguous_advanced_index = _map_index_base + + def _map_index_remapping_base(self, expr: IndexRemappingBase + ) -> FrozenSet[Array]: + return frozenset([expr.array]) + + map_roll = _map_index_remapping_base + map_axis_permutation = _map_index_remapping_base + map_reshape = _map_index_remapping_base + + def _map_input_base(self, expr: InputArgumentBase) -> FrozenSet[Array]: + return self._get_preds_from_shape(expr.shape) + + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + map_size_param = _map_input_base + + def map_distributed_recv(self, expr: DistributedRecv) -> FrozenSet[Array]: + return self._get_preds_from_shape(expr.shape) + + def map_distributed_send_ref_holder(self, + expr: DistributedSendRefHolder + ) -> FrozenSet[Array]: + return frozenset([expr.passthrough_data]) + +# }}} + + # {{{ NodeCountMapper class NodeCountMapper(CachedWalkMapper):