From c6fee290a4acaaeba9ad7589f6fa64920594c3ec Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 19 Jul 2021 20:56:25 -0500 Subject: [PATCH] adds a test for checking scalar operations --- test/test_codegen.py | 16 +++++++++++++++- test/test_pytato.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/test_codegen.py b/test/test_codegen.py index 44ddbe4..3cca1d0 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -979,6 +979,20 @@ def test_eye(ctx_factory, n, m, k): np.testing.assert_allclose(out.get(), np_eye) +@pytest.mark.parametrize("which,num_args", ([("maximum", 2), + ("minimum", 2), + ])) +def test_pt_ops_on_scalar_args_computed_eagerly(ctx_factory, which, num_args): + from numpy.random import default_rng + rng = default_rng() + args = [rng.random() for _ in range(num_args)] + + pt_func = getattr(pt, which) + np_func = getattr(np, which) + + np.testing.assert_allclose(pt_func(*args), np_func(*args)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) @@ -986,4 +1000,4 @@ if __name__ == "__main__": from pytest import main main([__file__]) -# vim: filetype=pyopencl:fdm=marker +# vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 1d8c6c0..b77c089 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -273,4 +273,4 @@ if __name__ == "__main__": from pytest import main main([__file__]) -# vim: filetype=pyopencl:fdm=marker +# vim: fdm=marker -- GitLab