From 9dfb4e243dc4488824fbf940a7e37302f3df39a0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 26 Mar 2022 22:30:39 -0500 Subject: [PATCH] make reverse_graph emit immutable objects --- doc/conf.py | 3 +++ pytato/transform.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index d81510b..2991a2f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -26,6 +26,7 @@ intersphinx_mapping = { "https://documen.tician.de/loopy/": None, "https://documen.tician.de/sumpy/": None, "https://documen.tician.de/islpy/": None, + "https://pyrsistent.readthedocs.io/en/latest/": None, } import sys @@ -33,4 +34,6 @@ sys.PYTATO_BUILDING_SPHINX_DOCS = True nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], + ["py:class", r"pyrsistent.typing.(.+)"], + ] diff --git a/pytato/transform.py b/pytato/transform.py index d38edc2..480a377 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -44,6 +44,8 @@ from pytato.array import ( from pytato.loopy import LoopyCall, LoopyCallResult from dataclasses import dataclass from pytato.tags import ImplStored +from pyrsistent import pmap +from pyrsistent.typing import PMap as PMapT if TYPE_CHECKING: from pytato.distributed import DistributedSendRefHolder, DistributedRecv @@ -1255,15 +1257,16 @@ def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, # {{{ operations on graphs in dict form -def reverse_graph(graph: Dict[ArrayOrNames, Set[ArrayOrNames]]) \ - -> Dict[ArrayOrNames, Set[ArrayOrNames]]: - """Reverses a graph. +def reverse_graph(graph: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]] + ) -> PMapT[ArrayOrNames, FrozenSet[ArrayOrNames]]: + """ + Reverses a graph. :param graph: A :class:`dict` representation of a directed graph, mapping each node to other nodes to which it is connected by edges. A possible use case for this function is the graph in :attr:`UsersCollector.node_to_users`. - :returns: A :class:`dict` representing *graph* with edges reversed. + :returns: A :class:`pyrsistent.PMap` representing *graph* with edges reversed. """ result: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} @@ -1274,7 +1277,7 @@ def reverse_graph(graph: Dict[ArrayOrNames, Set[ArrayOrNames]]) \ for other_node_key in edges: result.setdefault(other_node_key, set()).add(node_key) - return result + return pmap({k: frozenset(v) for k, v in result.items()}) def _recursively_get_all_users( -- GitLab