diff --git a/test/test_array.py b/test/test_array.py index 5f54cfa3ed636f65cf9859e4ac7e347ae02f64b6..3e74bcf0e2bc3c5c56ebfbb971164d89fcc49a35 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -581,7 +581,7 @@ def test_bitwise(ctx_factory): @pytest.mark.parametrize("rng_class", [RanluxGenerator, PhiloxGenerator, ThreefryGenerator]) -@pytest.mark.parametrize("ary_size", [300, 301, 302, 303, 10007]) +@pytest.mark.parametrize("ary_size", [300, 301, 302, 303, 10007, 1000000]) def test_random_float_in_range(ctx_factory, rng_class, ary_size, plot_hist=False): context = ctx_factory() queue = cl.CommandQueue(context) @@ -606,16 +606,22 @@ def test_random_float_in_range(ctx_factory, rng_class, ary_size, plot_hist=False pt.hist(ran.get(), 30) pt.show() - assert (0 < ran.get()).all() - assert (ran.get() < 1).all() + assert (0 <= ran.get()).all() + assert (ran.get() <= 1).all() if rng_class is RanluxGenerator: gen.synchronize(queue) ran = cl_array.zeros(queue, ary_size, dtype) gen.fill_uniform(ran, a=4, b=7) - assert (4 < ran.get()).all() - assert (ran.get() < 7).all() + ran_host = ran.get() + + for cond in [4 <= ran_host, ran_host <= 7]: + good = cond.all() + if not good: + print(np.where(~cond)) + print(ran_host[~cond]) + assert good ran = gen.normal(queue, ary_size, dtype, mu=10, sigma=3)