diff --git a/test/test_array.py b/test/test_array.py index c0015d8779120a4bba35e098c94632b495bf4e04..7722fde5cb441194acd6c5d4ece6dd0495dbd836 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -235,18 +235,21 @@ def test_random(ctx_factory): gen = RanluxGenerator(queue, 5120) - for dtype in dtypes: - ran = gen.uniform(queue, (10007,), dtype) - assert (0 < ran.get()).all() - assert (ran.get() < 1).all() + for ary_size in [300, 301, 302, 303, 10007]: + for dtype in dtypes: + ran = cl_array.zeros(queue, ary_size, dtype) + gen.fill_uniform(ran) + assert (0 < ran.get()).all() + assert (ran.get() < 1).all() - gen.synchronize(queue) + gen.synchronize(queue) - ran = gen.uniform(queue, (10007,), dtype, a=4, b=7) - assert (4 < ran.get()).all() - assert (ran.get() < 7).all() + ran = cl_array.zeros(queue, ary_size, dtype) + gen.fill_uniform(ran, a=4, b=7) + assert (4 < ran.get()).all() + assert (ran.get() < 7).all() - ran = gen.normal(queue, (10007,), dtype, mu=4, sigma=3) + ran = gen.normal(queue, (10007,), dtype, mu=4, sigma=3) dtypes = [np.int32] for dtype in dtypes: