From d82ada250f4a5db0fa3c864db7bc3a873477c7d9 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Tue, 27 Dec 2022 20:22:47 +0100 Subject: [PATCH] Remove unused reverse_graph (#376) * Deprecate reverse_graph in favor of pytools version * remove it and clean up tests * merge cleanup * more comprehensive tests * minor UsersCollector fixes, remove nonfunctional test part UsersCollector and NUsersCollector have slightly different behavior regarding duplicates (ignore in UsersCollector but counted in NUsersCollector), hence the previous test failed. * fix flake8 * Update __init__.py * restore reverse_graph in test * add progress indicator comment Co-authored-by: Andreas Kloeckner <inform@tiker.net> --- pytato/transform/__init__.py | 25 ------------------------- setup.py | 2 +- test/test_pytato.py | 19 +++++++++++++++++-- 3 files changed, 18 insertions(+), 28 deletions(-) mode change 100755 => 100644 test/test_pytato.py diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 32c399d..3d3e1a3 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,7 +46,6 @@ from pytato.distributed.nodes import ( from pytato.loopy import LoopyCall, LoopyCallResult from dataclasses import dataclass from pytato.tags import ImplStored -from immutables import Map from pymbolic.mapper.optimize import optimize_mapper ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] @@ -85,7 +84,6 @@ Dict representation of DAGs --------------------------- .. autoclass:: UsersCollector -.. autofunction:: reverse_graph .. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes @@ -1468,29 +1466,6 @@ def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, # {{{ operations on graphs in dict form -def reverse_graph(graph: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]] - ) -> Map[ArrayOrNames, FrozenSet[ArrayOrNames]]: - """ - Reverses a graph. - - :param graph: A :class:`dict` representation of a directed graph, mapping each - node to other nodes to which it is connected by edges. A possible - use case for this function is the graph in - :attr:`UsersCollector.node_to_users`. - :returns: A :class:`immutables.Map` representing *graph* with edges reversed. - """ - result: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} - - for node_key, edges in graph.items(): - # Make sure every node is in the result even if it has no users - result.setdefault(node_key, set()) - - for other_node_key in edges: - result.setdefault(other_node_key, set()).add(node_key) - - return Map({k: frozenset(v) for k, v in result.items()}) - - def _recursively_get_all_users( direct_users: Mapping[ArrayOrNames, Set[ArrayOrNames]], node: ArrayOrNames) -> FrozenSet[ArrayOrNames]: diff --git a/setup.py b/setup.py index ed32405..e00ab52 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ setup( python_requires="~=3.8", install_requires=[ "loopy>=2020.2", - "pytools>=2021.1", + "pytools>=2022.1.13", "immutables", "attrs", "bidict", diff --git a/test/test_pytato.py b/test/test_pytato.py old mode 100755 new mode 100644 index adfa482..f94aee9 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -324,7 +324,10 @@ def test_toposortmapper(): def test_userscollector(): from testlib import RandomDAGContext, make_random_dag - from pytato.transform import UsersCollector, reverse_graph + from pytato.transform import UsersCollector + from pytato.analysis import get_nusers + + from pytools.graph import reverse_graph # Check that nodes without users are correctly reversed array = pt.make_placeholder(name="array", shape=1, dtype=np.int64) @@ -332,15 +335,22 @@ def test_userscollector(): uc = UsersCollector() uc(y) + rev_graph = reverse_graph(uc.node_to_users) rev_graph2 = reverse_graph(reverse_graph(rev_graph)) assert dict(reverse_graph(rev_graph2)) == uc.node_to_users + assert len(uc.node_to_users) == 2 + assert uc.node_to_users[y] == set() + assert uc.node_to_users[array].pop() == y + assert len(uc.node_to_users[array]) == 0 + # Test random DAGs axis_len = 5 for i in range(100): + print(i) # progress indicator rdagc = RandomDAGContext(np.random.default_rng(seed=i), axis_len=axis_len, use_numpy=False) @@ -351,9 +361,14 @@ def test_userscollector(): rev_graph = reverse_graph(uc.node_to_users) rev_graph2 = reverse_graph(reverse_graph(rev_graph)) - assert rev_graph2 == rev_graph + nuc = get_nusers(dag) + + assert len(uc.node_to_users) == len(nuc)+1 + assert uc.node_to_users[dag] == set() + assert nuc[dag] == 0 + def test_asciidag(): pytest.importorskip("asciidag") -- GitLab