diff --git a/doc/conf.py b/doc/conf.py index d81510bb019b2a0feb23ebffef0d29739066d1a1..2991a2f4b73a6fbb8708a57fce8ce6a4d9cd84a0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -26,6 +26,7 @@ intersphinx_mapping = { "https://documen.tician.de/loopy/": None, "https://documen.tician.de/sumpy/": None, "https://documen.tician.de/islpy/": None, + "https://pyrsistent.readthedocs.io/en/latest/": None, } import sys @@ -33,4 +34,6 @@ sys.PYTATO_BUILDING_SPHINX_DOCS = True nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], + ["py:class", r"pyrsistent.typing.(.+)"], + ] diff --git a/pytato/transform.py b/pytato/transform.py index d38edc2c72b9018ae34c113645ab92f0316e2ab4..480a377eca1887c070a8c223cfeac5fb015387a6 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -44,6 +44,8 @@ from pytato.array import ( from pytato.loopy import LoopyCall, LoopyCallResult from dataclasses import dataclass from pytato.tags import ImplStored +from pyrsistent import pmap +from pyrsistent.typing import PMap as PMapT if TYPE_CHECKING: from pytato.distributed import DistributedSendRefHolder, DistributedRecv @@ -1255,15 +1257,16 @@ def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, # {{{ operations on graphs in dict form -def reverse_graph(graph: Dict[ArrayOrNames, Set[ArrayOrNames]]) \ - -> Dict[ArrayOrNames, Set[ArrayOrNames]]: - """Reverses a graph. +def reverse_graph(graph: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]] + ) -> PMapT[ArrayOrNames, FrozenSet[ArrayOrNames]]: + """ + Reverses a graph. :param graph: A :class:`dict` representation of a directed graph, mapping each node to other nodes to which it is connected by edges. A possible use case for this function is the graph in :attr:`UsersCollector.node_to_users`. - :returns: A :class:`dict` representing *graph* with edges reversed. + :returns: A :class:`pyrsistent.PMap` representing *graph* with edges reversed. """ result: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} @@ -1274,7 +1277,7 @@ def reverse_graph(graph: Dict[ArrayOrNames, Set[ArrayOrNames]]) \ for other_node_key in edges: result.setdefault(other_node_key, set()).add(node_key) - return result + return pmap({k: frozenset(v) for k, v in result.items()}) def _recursively_get_all_users(