diff --git a/pytato/transform.py b/pytato/transform.py index 4a838b5a2ffb867ba99d4ea913f3709110c003f2..ce7552671cedc70e792396cb1a8e4997d4d00275 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -559,6 +559,20 @@ class CachedWalkMapper(WalkMapper): # }}} +# {{{ TopoSortMapper + +class TopoSortMapper(CachedWalkMapper): + """A mapper that creates a list of nodes in topological order.""" + + def __init__(self) -> None: + super().__init__() + self.topological_order = [] + + def post_visit(self, expr: Any) -> bool: + self.topological_order.append(expr) + +# }}} + # {{{ mapper frontends def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, diff --git a/test/test_pytato.py b/test/test_pytato.py index 5c3fdb519778891a2c5b1e6596c8b58570f6ecd1..78a8b2068728e605f1efaadebc714884eefb94a8 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -278,6 +278,27 @@ def test_dict_of_named_arrays_comparison(): assert dict1 != dict4 +def test_toposortmapper(): + n = pt.make_size_param("n") + array = pt.make_placeholder(name="array", shape=n, dtype=np.float64) + stack = pt.stack([array, 2*array, array + 6]) + y = stack @ stack.T + + tm = pt.transform.TopoSortMapper() + tm(y) + + from pytato.array import (AxisPermutation, IndexLambda, MatrixProduct, + Placeholder, SizeParam, Stack) + + assert isinstance(tm.topological_order[0], SizeParam) + assert isinstance(tm.topological_order[1], Placeholder) + assert isinstance(tm.topological_order[2], IndexLambda) + assert isinstance(tm.topological_order[3], IndexLambda) + assert isinstance(tm.topological_order[4], Stack) + assert isinstance(tm.topological_order[5], AxisPermutation) + assert isinstance(tm.topological_order[6], MatrixProduct) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])