From 14db4e921b9cf7f375bef72f6ae40ab76cfcf0ae Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 21 Jul 2021 18:47:48 -0500 Subject: [PATCH] tests pt.(v?)dot --- test/test_codegen.py | 48 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/test_codegen.py b/test/test_codegen.py index 3cca1d0..202958a 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -993,6 +993,54 @@ def test_pt_ops_on_scalar_args_computed_eagerly(ctx_factory, which, num_args): np.testing.assert_allclose(pt_func(*args), np_func(*args)) +@pytest.mark.parametrize("a_shape,b_shape", ([((10,), (10,)), + ((10, 4), (4, 10)), + ((10, 2, 2), (2,)), + ((10, 5, 2, 7), (3, 7, 4))])) +@pytest.mark.parametrize("a_dtype", [np.float32, np.complex64]) +@pytest.mark.parametrize("b_dtype", [np.float32, np.complex64]) +def test_dot(ctx_factory, a_shape, b_shape, a_dtype, b_dtype): + from numpy.random import default_rng + rng = default_rng() + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + a_in = (rng.random(a_shape) + 1j * rng.random(a_shape)).astype(a_dtype) + b_in = (rng.random(b_shape) + 1j * rng.random(b_shape)).astype(b_dtype) + a = pt.make_data_wrapper(a_in) + b = pt.make_data_wrapper(b_in) + + np_result = np.dot(a_in, b_in) + _, (pt_result,) = pt.generate_loopy(pt.dot(a, b))(cq) + + assert pt_result.shape == np_result.shape + assert pt_result.dtype == np_result.dtype + np.testing.assert_allclose(np_result, pt_result, rtol=1e-6) + + +@pytest.mark.parametrize("a_shape,b_shape", ([((10,), (10,)), + ((10, 4), (4, 10))])) +@pytest.mark.parametrize("a_dtype", [np.float32, np.complex64]) +@pytest.mark.parametrize("b_dtype", [np.float32, np.complex64]) +def test_vdot(ctx_factory, a_shape, b_shape, a_dtype, b_dtype): + from numpy.random import default_rng + rng = default_rng() + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + a_in = (rng.random(a_shape) + 1j * rng.random(a_shape)).astype(a_dtype) + b_in = (rng.random(b_shape) + 1j * rng.random(b_shape)).astype(b_dtype) + a = pt.make_data_wrapper(a_in) + b = pt.make_data_wrapper(b_in) + + np_result = np.vdot(a_in, b_in) + _, (pt_result,) = pt.generate_loopy(pt.vdot(a, b))(cq) + + assert pt_result.shape == np_result.shape + assert pt_result.dtype == np_result.dtype + np.testing.assert_allclose(np_result, pt_result, rtol=1e-6) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab