diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index 426a8cff511b89765cdb6662fdec8bad38b522b4..433fda1fc9090ed24092c0fa41e7e508e94e8283 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -65,7 +65,7 @@ from functools import reduce
 import collections
 from typing import (
         Iterator, Iterable, Sequence, Any, Mapping, FrozenSet, Set, Dict, cast,
-        List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional)
+        List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional, Tuple)
 
 import attrs
 from immutabledict import immutabledict
@@ -509,48 +509,96 @@ class _LocalSendRecvDepGatherer(
 # }}}
 
 
-# {{{ _schedule_comm_batches
+TaskType = TypeVar("TaskType")
 
-def _schedule_comm_batches(
-        comm_ids_to_needed_comm_ids: CommunicationDepGraph
-        ) -> Sequence[AbstractSet[CommunicationOpIdentifier]]:
-    """For each :class:`CommunicationOpIdentifier`, determine the
+
+# {{{ _schedule_task_batches (and related)
+
+def _schedule_task_batches(
+        task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \
+        -> Sequence[AbstractSet[TaskType]]:
+    """For each :type:`TaskType`, determine the
     'round'/'batch' during which it will be performed. A 'batch'
-    of communication consists of sends and receives. Computation
-    occurs between batches. (So, from the perspective of the
-    :class:`DistributedGraphPartition`, communication batches
-    sit *between* parts.)
+    of tasks consists of tasks that do not depend on each other.
+    A task may be placed in a batch only once all of its dependencies
+    have been completed.
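+
+    For instance (an illustrative sketch), the dependency graph::
+
+        {"a": set(), "b": {"a"}, "c": {"a"}}
+
+    would be scheduled as the two batches ``[{"a"}, {"b", "c"}]``.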
     """
-    # FIXME: I'm an O(n^2) algorithm.
+    return _schedule_task_batches_counted(task_ids_to_needed_task_ids)[0]
+# }}}
+
+
+# {{{ _schedule_task_batches_counted
+
+def _schedule_task_batches_counted(
+        task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \
+        -> Tuple[Sequence[AbstractSet[TaskType]], int]:
+    """
+    Like :func:`_schedule_task_batches`, but also return the number of tasks
+    visited during scheduling. The test code needs this count to verify the
+    algorithm's complexity; nontest code does not. Since static type checkers
+    require a function to return a consistent type, the counted variant is
+    split out into this separate function.
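+
+    For instance (hypothetical usage)::
+
+        batches, num_visits = _schedule_task_batches_counted(needed_task_ids)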
+    """
+    task_to_dep_level, visits_in_depend = \
+            _calculate_dependency_levels(task_ids_to_needed_task_ids)
+    nlevels = 1 + max(task_to_dep_level.values(), default=-1)
+    task_batches: Sequence[Set[TaskType]] = [set() for _ in range(nlevels)]
+
+    for task_id, dep_level in task_to_dep_level.items():
+        task_batches[dep_level].add(task_id)
 
-    comm_batches: List[AbstractSet[CommunicationOpIdentifier]] = []
+    return task_batches, visits_in_depend + len(task_to_dep_level)
 
-    scheduled_comm_ids: Set[CommunicationOpIdentifier] = set()
-    comms_to_schedule = set(comm_ids_to_needed_comm_ids)
+# }}}
+
+
+# {{{ _calculate_dependency_levels
 
-    all_comm_ids = frozenset(comm_ids_to_needed_comm_ids)
+def _calculate_dependency_levels(
+        task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]
+        ) -> Tuple[Mapping[TaskType, int], int]:
+    """Calculate the minimum dependendency level needed before a task of
+    type TaskType can be scheduled. We assume that any number of tasks
+    can be scheduled at the same time. To attain complexity linear in the
+    number of nodes, we assume that each task has a constant number of direct
+    dependents.
+
+    The minimum dependency level for a task, i, is defined as
+    1 + the maximum dependency level for its children.
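+
+    For example (an illustrative sketch), for the graph
+    ``{"a": set(), "b": {"a"}, "c": {"b"}}`` the computed levels are
+    ``{"a": 0, "b": 1, "c": 2}``.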
+    """
+    task_to_dep_level: Dict[TaskType, int] = {}
+    seen: Set[TaskType] = set()
+    nodes_visited: int = 0
 
-    # FIXME In order for this to work, comm tags must be unique
-    while len(scheduled_comm_ids) < len(all_comm_ids):
-        comm_ids_this_batch = {
-                comm_id for comm_id in comms_to_schedule
-                if comm_ids_to_needed_comm_ids[comm_id] <= scheduled_comm_ids}
+    def _dependency_level_dfs(task_id: TaskType) -> int:
+        """Helper function to do depth first search on a graph."""
 
-        if not comm_ids_this_batch:
-            raise CycleError("cycle detected in communication graph")
+        if task_id in task_to_dep_level:
+            return task_to_dep_level[task_id]
 
-        scheduled_comm_ids.update(comm_ids_this_batch)
-        comms_to_schedule = comms_to_schedule - comm_ids_this_batch
+        # If node has been 'seen', but dep level is not yet known, that's a cycle.
+        if task_id in seen:
+            raise CycleError("cycle detected in task graph")
+        seen.add(task_id)
 
-        comm_batches.append(comm_ids_this_batch)
+        nonlocal nodes_visited
+        nodes_visited += 1
 
-    return comm_batches
+        dep_level = 1 + max(
+                (_dependency_level_dfs(dep)
+                    for dep in task_ids_to_needed_task_ids[task_id]),
+                default=-1)
+        task_to_dep_level[task_id] = dep_level
+        return dep_level
+
+    for task_id in task_ids_to_needed_task_ids:
+        _dependency_level_dfs(task_id)
+
+    return task_to_dep_level, nodes_visited
 
 # }}}
 
 
 # {{{  _MaterializedArrayCollector
 
 @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
 class _MaterializedArrayCollector(CachedWalkMapper):
     """
@@ -751,7 +799,7 @@ def find_distributed_partition(
         # The comm_batches correspond one-to-one to DistributedGraphParts
         # in the output.
         try:
-            comm_batches = _schedule_comm_batches(comm_ids_to_needed_comm_ids)
+            comm_batches = _schedule_task_batches(comm_ids_to_needed_comm_ids)
         except Exception as exc:
             mpi_communicator.bcast(exc)
             raise
@@ -771,7 +819,6 @@ def find_distributed_partition(
     # {{{ create (local) parts out of batch ids
 
     part_comm_ids: List[_PartCommIDs] = []
-
     if comm_batches:
         recv_ids: FrozenSet[CommunicationOpIdentifier] = frozenset()
         for batch in comm_batches:
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 8d7cd50dd5d5e5f6e1c13b2f01f7a3d659335990..f7a8e5b4c37cfb17ee3d7db7f65fbc6c30431116 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -118,6 +118,124 @@ def _do_test_distributed_execution_basic(ctx_factory):
 # }}}
 
 
+# {{{ test_distributed_scheduler_counts
+
+def test_distributed_scheduler_counts():
+    """ Test that the scheduling algorithm runs in `O(n)` time when
+    operating on a DAG which is just a stick with the dependencies
+    implied and not directly listed.
+    """
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    sizes = np.logspace(0, 6, 10, dtype=int)
+    count_list = np.zeros(len(sizes))
+    for i, tree_size in enumerate(sizes):
+        needed_ids = {j: ({j - 1} if j > 0 else set())
+                      for j in range(int(tree_size))}
+        _, count_list[i] = _schedule_task_batches_counted(needed_ids)
+
+    # Fit a quartic; for linear scaling, the coefficients of degree 2 and
+    # above should vanish.
+    coefficients = np.polyfit(sizes, count_list, 4)
+    import numpy.linalg as la
+    nonlinear_norm_frac = la.norm(coefficients[:-2], 2)/la.norm(coefficients, 2)
+    assert nonlinear_norm_frac < 0.0001
+
+# }}}
+
+
+# {{{ test_distributed_scheduler_returns_minimum_num_of_levels
+
+def test_distributed_scheduler_returns_minimum_num_of_levels():
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    max_size = 10
+    needed_ids = {j: set() for j in range(max_size)}
+    for i in range(1, max_size-1):
+        needed_ids[i].add(i-1)
+
+    batches, _ = _schedule_task_batches_counted(needed_ids)
+    # The last task has no dependencies listed, so it can be placed anywhere.
+    assert len(batches) == (max_size - 1)
+
+# }}}
+
+
+# {{{  test_distributed_scheduling_alg_can_find_cycle
+
+def test_distributed_scheduling_alg_can_find_cycle():
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    from pytools.graph import CycleError
+    size = 100
+    my_graph = {i: {i-1} for i in range(size)}
+    my_graph[0] = set()
+    my_graph[60].add(95)  # Adds a cycle: 60 -> 95 -> 94 -> ... -> 61 -> 60
+    with pytest.raises(CycleError):
+        _schedule_task_batches_counted(my_graph)
+
+# }}}
+
+
+# {{{ test_distributed_scheduling_o_n_direct_dependents
+
+def test_distributed_scheduling_o_n_direct_dependents():
+    """ Check that the temporal complexity of the scheduling algorithm
+    in the case that there are `O(n)` direct dependents for each task
+    is not cubic.
+    """
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    sizes = np.logspace(0, 4, 10, dtype=int)
+    count_list = np.zeros(len(sizes))
+    for i, tree_size in enumerate(sizes):
+        needed_ids = {j: set(range(j)) for j in range(int(tree_size))}
+        _, count_list[i] = _schedule_task_batches_counted(needed_ids)
+
+    # Fit a quartic to the visit counts.
+    coefficients = np.polyfit(sizes, count_list, 4)
+    import numpy.linalg as la
+    # We are expecting less than cubic scaling, so the cubic and quartic
+    # coefficients should vanish.
+    nonquadratic_norm_frac = la.norm(coefficients[:-3], 2)/la.norm(coefficients, 2)
+    assert nonquadratic_norm_frac < 0.0001
+
+# }}}
+
+
+# {{{ test_distributed_scheduling_constant_look_back_tree
+
+def test_distributed_scheduling_constant_look_back_tree():
+    """Test that the scheduling algorithm scales in linear time if the input DAG
+    is a constant look back tree. This tree has a single root and then 5 tendrils
+    off of this root. Along the tendril each node has a direct dependence on the
+    previous one in the tendril but no other direct dependencies. This is intended
+    to confirm that the scheduling algorithm utilizing the minimum number of batch
+    levels possible.
+    """
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    import math
+    sizes = np.logspace(0, 6, 10, dtype=int)
+    count_list = np.zeros(len(sizes))
+    branching_factor = 5
+    for i, tree_size in enumerate(sizes):
+        needed_ids = {j: set() for j in range(int(tree_size))}
+        for j in range(1, int(tree_size)):
+            if j <= branching_factor:
+                # The first task of each tendril depends on the root.
+                needed_ids[j] = {0}
+            else:
+                # Later tasks depend on the previous task in their tendril.
+                needed_ids[j] = {j - branching_factor}
+        batches, count_list[i] = _schedule_task_batches_counted(needed_ids)
+
+        # Test that the number of batches is the expected minimum number.
+        assert len(batches) == math.ceil((tree_size - 1) / branching_factor) + 1
+
+    # Fit a quartic; for linear scaling, the coefficients of degree 2 and
+    # above should vanish.
+    coefficients = np.polyfit(sizes, count_list, 4)
+    import numpy.linalg as la
+    nonlinear_norm_frac = la.norm(coefficients[:-2], 2)/la.norm(coefficients, 2)
+    assert nonlinear_norm_frac < 0.0001
+
+# }}}
+
+
 # {{{ test based on random dag
 
 def test_distributed_execution_random_dag():