diff --git a/test/test_codegen.py b/test/test_codegen.py index 44ddbe48bb3fb76fca8fde5a0fe893c2dcc840a8..3cca1d0d63641a153963d635abf9cf9bd84ec47c 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -979,6 +979,20 @@ def test_eye(ctx_factory, n, m, k): np.testing.assert_allclose(out.get(), np_eye) +@pytest.mark.parametrize("which,num_args", ([("maximum", 2), + ("minimum", 2), + ])) +def test_pt_ops_on_scalar_args_computed_eagerly(ctx_factory, which, num_args): + from numpy.random import default_rng + rng = default_rng() + args = [rng.random() for _ in range(num_args)] + + pt_func = getattr(pt, which) + np_func = getattr(np, which) + + np.testing.assert_allclose(pt_func(*args), np_func(*args)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) @@ -986,4 +1000,4 @@ if __name__ == "__main__": from pytest import main main([__file__]) -# vim: filetype=pyopencl:fdm=marker +# vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 1d8c6c03d642073604f245ce949b5abff12cd6df..b77c08913881baa9031cb162ac30dd6b160b943e 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -273,4 +273,4 @@ if __name__ == "__main__": from pytest import main main([__file__]) -# vim: filetype=pyopencl:fdm=marker +# vim: fdm=marker