From 1f38f45a3998479324a911274d1a896c6448b0f3 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 14 Mar 2023 13:25:55 -0500 Subject: [PATCH] Test NamedArrays for JAXTarget --- test/test_jax.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_jax.py b/test/test_jax.py index 0dd8f6e..dc93857 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -127,3 +127,14 @@ def test_placeholders_in_jax(jit): np_out = img_in * scl_in np.testing.assert_allclose(pt_out, np_out, rtol=1e-6) + + +@pytest.mark.parametrize("jit", ([False, True])) +def test_exprs_with_named_array(jit): + # pytato.git <= cf3673a would fail this regression + x_in = np.random.rand(10, 4) + x = pt.make_data_wrapper(x_in) + y1y2 = pt.make_dict_of_named_arrays({"y1": 2*x, "y2": 3*x}) + res = 21*y1y2["y1"] + out = pt.generate_jax(res)() + np.testing.assert_allclose(out, 42*x_in) -- GitLab