diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 32c399d408565902eace21607d1d317858808b3a..3d3e1a34bed17973f7cc36927c379c7039e84843 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,7 +46,6 @@ from pytato.distributed.nodes import ( from pytato.loopy import LoopyCall, LoopyCallResult from dataclasses import dataclass from pytato.tags import ImplStored -from immutables import Map from pymbolic.mapper.optimize import optimize_mapper ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] @@ -85,7 +84,6 @@ Dict representation of DAGs --------------------------- .. autoclass:: UsersCollector -.. autofunction:: reverse_graph .. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes @@ -1468,29 +1466,6 @@ def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, # {{{ operations on graphs in dict form -def reverse_graph(graph: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]] - ) -> Map[ArrayOrNames, FrozenSet[ArrayOrNames]]: - """ - Reverses a graph. - - :param graph: A :class:`dict` representation of a directed graph, mapping each - node to other nodes to which it is connected by edges. A possible - use case for this function is the graph in - :attr:`UsersCollector.node_to_users`. - :returns: A :class:`immutables.Map` representing *graph* with edges reversed. - """ - result: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} - - for node_key, edges in graph.items(): - # Make sure every node is in the result even if it has no users - result.setdefault(node_key, set()) - - for other_node_key in edges: - result.setdefault(other_node_key, set()).add(node_key) - - return Map({k: frozenset(v) for k, v in result.items()}) - - def _recursively_get_all_users( direct_users: Mapping[ArrayOrNames, Set[ArrayOrNames]], node: ArrayOrNames) -> FrozenSet[ArrayOrNames]: diff --git a/setup.py b/setup.py index ed32405e697839a4c451a62e5b0aba4a6572759a..e00ab52548e0fb0c3ddb775a913b2f01cba12e21 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ setup( python_requires="~=3.8", install_requires=[ "loopy>=2020.2", - "pytools>=2021.1", + "pytools>=2022.1.13", "immutables", "attrs", "bidict", diff --git a/test/test_pytato.py b/test/test_pytato.py old mode 100755 new mode 100644 index adfa48242e2d4660b87a95c35907a1df5d589e03..f94aee98525dbd34b1950d94985ddbdd4618c66a --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -324,7 +324,10 @@ def test_toposortmapper(): def test_userscollector(): from testlib import RandomDAGContext, make_random_dag - from pytato.transform import UsersCollector, reverse_graph + from pytato.transform import UsersCollector + from pytato.analysis import get_nusers + + from pytools.graph import reverse_graph # Check that nodes without users are correctly reversed array = pt.make_placeholder(name="array", shape=1, dtype=np.int64) @@ -332,15 +335,22 @@ def test_userscollector(): uc = UsersCollector() uc(y) + rev_graph = reverse_graph(uc.node_to_users) rev_graph2 = reverse_graph(reverse_graph(rev_graph)) assert dict(reverse_graph(rev_graph2)) == uc.node_to_users + assert len(uc.node_to_users) == 2 + assert uc.node_to_users[y] == set() + assert uc.node_to_users[array].pop() == y + assert len(uc.node_to_users[array]) == 0 + # Test random DAGs axis_len = 5 for i in range(100): + print(i) # progress indicator rdagc = RandomDAGContext(np.random.default_rng(seed=i), axis_len=axis_len, use_numpy=False) @@ -351,9 +361,14 @@ def test_userscollector(): rev_graph = reverse_graph(uc.node_to_users) rev_graph2 = reverse_graph(reverse_graph(rev_graph)) - assert rev_graph2 == rev_graph + nuc = get_nusers(dag) + + assert len(uc.node_to_users) == len(nuc)+1 + assert uc.node_to_users[dag] == set() + assert nuc[dag] == 0 + def test_asciidag(): pytest.importorskip("asciidag")