From c0188b58766d954a7cb159fa6d8831b98bc7c8bd Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 14 Aug 2018 13:44:13 -0500 Subject: [PATCH] Fix bounds to be inclusive in test_random_float_in_range --- test/test_array.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/test_array.py b/test/test_array.py index 5f54cfa3..3e74bcf0 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -581,7 +581,7 @@ def test_bitwise(ctx_factory): @pytest.mark.parametrize("rng_class", [RanluxGenerator, PhiloxGenerator, ThreefryGenerator]) -@pytest.mark.parametrize("ary_size", [300, 301, 302, 303, 10007]) +@pytest.mark.parametrize("ary_size", [300, 301, 302, 303, 10007, 1000000]) def test_random_float_in_range(ctx_factory, rng_class, ary_size, plot_hist=False): context = ctx_factory() queue = cl.CommandQueue(context) @@ -606,16 +606,22 @@ def test_random_float_in_range(ctx_factory, rng_class, ary_size, plot_hist=False pt.hist(ran.get(), 30) pt.show() - assert (0 < ran.get()).all() - assert (ran.get() < 1).all() + assert (0 <= ran.get()).all() + assert (ran.get() <= 1).all() if rng_class is RanluxGenerator: gen.synchronize(queue) ran = cl_array.zeros(queue, ary_size, dtype) gen.fill_uniform(ran, a=4, b=7) - assert (4 < ran.get()).all() - assert (ran.get() < 7).all() + ran_host = ran.get() + + for cond in [4 <= ran_host, ran_host <= 7]: + good = cond.all() + if not good: + print(np.where(~cond)) + print(ran_host[~cond]) + assert good ran = gen.normal(queue, ary_size, dtype, mu=10, sigma=3) -- GitLab