diff --git a/test/test_codegen.py b/test/test_codegen.py index 6ca17b4f08708e4bf4928252ac0d277ce21954d6..90cb68b52436d24a0f82ff582fb1ff28336dffb8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1034,6 +1034,27 @@ def test_vdot(ctx_factory, a_shape, b_shape, a_dtype, b_dtype): np.testing.assert_allclose(np_result, pt_result, rtol=1e-6) +def test_reduction_adds_deps(ctx_factory): + from numpy.random import default_rng + + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + rng = default_rng() + x_in = rng.random(10) + x = pt.make_data_wrapper(x_in) + y = 2*x + z = pt.sum(y) + pt_prg = pt.generate_loopy({"y": y, "z": z}) + + assert ("y_store" + in pt_prg.program.default_entrypoint.id_to_insn["z_store"].depends_on) + + _, out_dict = pt_prg(queue) + np.testing.assert_allclose(np.sum(2*x_in), + out_dict["z"]) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])