From a6240512549f88190e7cb722053ae42b322d8ebb Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 23 May 2023 16:07:44 -0500 Subject: [PATCH] test visualizers Co-authored-by: Andreas Kloeckner --- .test-conda-env-py3.yml | 1 + test/test_pytato.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 48be7bc..9d7972c 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -14,3 +14,4 @@ dependencies: - mpi4py - jax - openmpi # Force using Open MPI since our pytest infrastructure needs it +- graphviz # for visualization tests diff --git a/test/test_pytato.py b/test/test_pytato.py index b16d56e..5d2343e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -35,6 +35,7 @@ import pytato as pt from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) +from testlib import RandomDAGContext, make_random_dag def test_matmul_input_validation(): @@ -1115,6 +1116,35 @@ def test_rewrite_einsums_with_no_broadcasts(): assert pt.analysis.is_einsum_similar_to_subscript(new_expr.args[2], "ij,ik->ijk") +def test_dot_visualizers(): + a = pt.make_placeholder("A", shape=(10, 4), dtype=np.float64) + x1 = pt.make_placeholder("x1", shape=4, dtype=np.float64) + x2 = pt.make_placeholder("x2", shape=4, dtype=np.float64) + + y = a @ (2*x1 + 3*x2) + + axis_len = 5 + + graphs = [y] + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=False) + graphs.append(make_random_dag(rdagc)) + + # {{{ ensure that the generated output is valid dot-lang + + # TODO: Verify the soundness of the generated svg file + + for graph in graphs: + # plot to .svg file to avoid dep on a webbrowser or X-window system + pt.show_dot_graph(graph, output_to="svg") + + pt.show_fancy_placeholder_data_flow(y, output_to="svg") + + # }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab