diff --git a/test/test_jax.py b/test/test_jax.py index 0dd8f6eb4748e40af80740de602baabaafc7f5c5..dc93857fadf4ea1103a705800a8cf93299c61f1d 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -127,3 +127,14 @@ def test_placeholders_in_jax(jit): np_out = img_in * scl_in np.testing.assert_allclose(pt_out, np_out, rtol=1e-6) + + +@pytest.mark.parametrize("jit", ([False, True])) +def test_exprs_with_named_array(jit): + # pytato.git <= cf3673a would fail this regression + x_in = np.random.rand(10, 4) + x = pt.make_data_wrapper(x_in) + y1y2 = pt.make_dict_of_named_arrays({"y1": 2*x, "y2": 3*x}) + res = 21*y1y2["y1"] + out = pt.generate_jax(res)() + np.testing.assert_allclose(out, 42*x_in)