diff --git a/test/test_array.py b/test/test_array.py index 8c4ae9b7aacb7501ad5d871c53335c0645e2c1bd..ff2b94ca1413a27282876014023aa0107844461d 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -413,7 +413,7 @@ def test_divide_array(ctx_factory): # {{{ RNG -def test_random(ctx_factory): +def test_random_float_in_range(ctx_factory): context = ctx_factory() queue = cl.CommandQueue(context) @@ -442,14 +442,25 @@ def test_random(ctx_factory): ran = gen.normal(queue, (10007,), dtype, mu=4, sigma=3) - dtypes = [np.int32, np.int64] - for dtype in dtypes: - ran = gen.uniform(queue, (10000007,), dtype, a=200, b=300) - assert (200 <= ran.get()).all() - assert (ran.get() < 300).all() - #from matplotlib import pyplot as pt - #pt.hist(ran.get()) - #pt.show() + +@pytest.mark.parametrize("dtype", [np.int32, np.int64]) +def test_random_int_in_range(ctx_factory, dtype): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.clrandom import RanluxGenerator + gen = RanluxGenerator(queue, 5120) + + if (dtype == np.int64 + and context.devices[0].platform.vendor.startswith("Advanced Micro")): + pytest.xfail("AMD miscompiles 64-bit RNG math") + + ran = gen.uniform(queue, (10000007,), dtype, a=200, b=300) + assert (200 <= ran.get()).all() + assert (ran.get() < 300).all() + #from matplotlib import pyplot as pt + #pt.hist(ran.get()) + #pt.show() # }}}