diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 48be7bc1916587eb3967a71f97d2d04149f1d275..9d7972cef3846479840164a7aa96cf783ce0f772 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -14,3 +14,4 @@ dependencies: - mpi4py - jax - openmpi # Force using Open MPI since our pytest infrastructure needs it +- graphviz # for visualization tests diff --git a/test/test_pytato.py b/test/test_pytato.py index b16d56e0f14ece5db735568c9e8df046418451a4..5d2343ee9543f68c6cd58ac206b00d83ee6e2041 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -35,6 +35,7 @@ import pytato as pt from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) +from testlib import RandomDAGContext, make_random_dag def test_matmul_input_validation(): @@ -1115,6 +1116,35 @@ def test_rewrite_einsums_with_no_broadcasts(): assert pt.analysis.is_einsum_similar_to_subscript(new_expr.args[2], "ij,ik->ijk") +def test_dot_visualizers(): + a = pt.make_placeholder("A", shape=(10, 4), dtype=np.float64) + x1 = pt.make_placeholder("x1", shape=4, dtype=np.float64) + x2 = pt.make_placeholder("x2", shape=4, dtype=np.float64) + + y = a @ (2*x1 + 3*x2) + + axis_len = 5 + + graphs = [y] + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=False) + graphs.append(make_random_dag(rdagc)) + + # {{{ ensure that the generated output is valid dot-lang + + # TODO: Verify the soundness of the generated svg file + + for graph in graphs: + # plot to .svg file to avoid dep on a webbrowser or X-window system + pt.show_dot_graph(graph, output_to="svg") + + pt.show_fancy_placeholder_data_flow(y, output_to="svg") + + # }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])