From cc4b03918d9893583da2b41e682a811fbcf20ff0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 24 Mar 2022 17:04:32 -0500 Subject: [PATCH] test pickling support for DAGs --- test/test_pytato.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 5f602a2..a674098 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -790,6 +790,54 @@ def test_einsum_dot_axes_has_correct_dim(): assert len(einsum.axes) == einsum.ndim +def test_pickling_and_unpickling_is_equal(): + from testlib import RandomDAGContext, make_random_dag + import pickle + from pytools import UniqueNameGenerator + axis_len = 5 + + for i in range(50): + print(i) # progress indicator + + seed = 120 + i + rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False) + + dag = pt.make_dict_of_named_arrays({"out": make_random_dag(rdagc_pt)}) + + # {{{ convert data-wrappers to placeholders + + vng = UniqueNameGenerator() + + def make_dws_placeholder(expr): + if isinstance(expr, pt.DataWrapper): + return pt.make_placeholder(vng("_pt_ph"), + expr.shape, expr.dtype) + else: + return expr + + dag = pt.transform.map_and_copy(dag, make_dws_placeholder) + + # }}} + + assert pickle.loads(pickle.dumps(dag)) == dag + + # {{{ adds an example which guarantees NaN in expression tree + + # pytato<=d015f914 used IEEE-representation of NaN in its expression graphs + # and since NaN != NaN the following assertions would fail. + + x = pt.make_placeholder("x", shape=(10, 4), dtype="float64") + expr = pt.maximum(2*x, 3*x) + + assert pickle.loads(pickle.dumps(expr)) == expr + + expr = pt.full((10, 4), np.nan) + assert pickle.loads(pickle.dumps(expr)) == expr + + # }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab