diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index 32c399d408565902eace21607d1d317858808b3a..3d3e1a34bed17973f7cc36927c379c7039e84843 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -46,7 +46,6 @@ from pytato.distributed.nodes import (
 from pytato.loopy import LoopyCall, LoopyCallResult
 from dataclasses import dataclass
 from pytato.tags import ImplStored
-from immutables import Map
 from pymbolic.mapper.optimize import optimize_mapper
 
 ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
@@ -85,7 +84,6 @@ Dict representation of DAGs
 ---------------------------
 
 .. autoclass:: UsersCollector
-.. autofunction:: reverse_graph
 .. autofunction:: tag_user_nodes
 .. autofunction:: rec_get_user_nodes
 
@@ -1468,29 +1466,6 @@ def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames,
 
 # {{{ operations on graphs in dict form
 
-def reverse_graph(graph: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]]
-                  ) -> Map[ArrayOrNames, FrozenSet[ArrayOrNames]]:
-    """
-    Reverses a graph.
-
-    :param graph: A :class:`dict` representation of a directed graph, mapping each
-        node to other nodes to which it is connected by edges. A possible
-        use case for this function is the graph in
-        :attr:`UsersCollector.node_to_users`.
-    :returns: A :class:`immutables.Map` representing *graph* with edges reversed.
-    """
-    result: Dict[ArrayOrNames, Set[ArrayOrNames]] = {}
-
-    for node_key, edges in graph.items():
-        # Make sure every node is in the result even if it has no users
-        result.setdefault(node_key, set())
-
-        for other_node_key in edges:
-            result.setdefault(other_node_key, set()).add(node_key)
-
-    return Map({k: frozenset(v) for k, v in result.items()})
-
-
 def _recursively_get_all_users(
         direct_users: Mapping[ArrayOrNames, Set[ArrayOrNames]],
         node: ArrayOrNames) -> FrozenSet[ArrayOrNames]:
diff --git a/setup.py b/setup.py
index ed32405e697839a4c451a62e5b0aba4a6572759a..e00ab52548e0fb0c3ddb775a913b2f01cba12e21 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@ setup(
     python_requires="~=3.8",
     install_requires=[
         "loopy>=2020.2",
-        "pytools>=2021.1",
+        "pytools>=2022.1.13",
         "immutables",
         "attrs",
         "bidict",
diff --git a/test/test_pytato.py b/test/test_pytato.py
old mode 100755
new mode 100644
index adfa48242e2d4660b87a95c35907a1df5d589e03..f94aee98525dbd34b1950d94985ddbdd4618c66a
--- a/test/test_pytato.py
+++ b/test/test_pytato.py
@@ -324,7 +324,10 @@ def test_toposortmapper():
 
 def test_userscollector():
     from testlib import RandomDAGContext, make_random_dag
-    from pytato.transform import UsersCollector, reverse_graph
+    from pytato.transform import UsersCollector
+    from pytato.analysis import get_nusers
+
+    from pytools.graph import reverse_graph
 
     # Check that nodes without users are correctly reversed
     array = pt.make_placeholder(name="array", shape=1, dtype=np.int64)
@@ -332,15 +335,22 @@ def test_userscollector():
 
     uc = UsersCollector()
     uc(y)
+
     rev_graph = reverse_graph(uc.node_to_users)
     rev_graph2 = reverse_graph(reverse_graph(rev_graph))
 
     assert dict(reverse_graph(rev_graph2)) == uc.node_to_users
 
+    assert len(uc.node_to_users) == 2
+    assert uc.node_to_users[y] == set()
+    assert uc.node_to_users[array].pop() == y
+    assert len(uc.node_to_users[array]) == 0
+
     # Test random DAGs
     axis_len = 5
 
     for i in range(100):
+        print(i)  # progress indicator
         rdagc = RandomDAGContext(np.random.default_rng(seed=i),
                 axis_len=axis_len, use_numpy=False)
 
@@ -351,9 +361,14 @@ def test_userscollector():
 
         rev_graph = reverse_graph(uc.node_to_users)
         rev_graph2 = reverse_graph(reverse_graph(rev_graph))
-
         assert rev_graph2 == rev_graph
 
+        nuc = get_nusers(dag)
+
+        assert len(uc.node_to_users) == len(nuc)+1
+        assert uc.node_to_users[dag] == set()
+        assert nuc[dag] == 0
+
 
 def test_asciidag():
     pytest.importorskip("asciidag")