diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py index 5b73e6a039a42d613eb6502eaacc7250e5e47873..db07d9b3cf82da2401cf7cadc713764ff1ab9305 100644 --- a/pyopencl/clrandom.py +++ b/pyopencl/clrandom.py @@ -223,6 +223,17 @@ class RanluxGenerator(object): rng_expr = ("(shift " "+ convert_int4((float) scale * gen) " "+ convert_int4((float) (scale / (1<<24)) * gen))") + + elif dtype == np.int64: + assert distribution == "uniform" + bits = 64 + c_type = "long" + rng_expr = ("(shift " + "+ convert_long4((float) scale * gen) " + "+ convert_long4((float) (scale / (1<<24)) * gen)" + "+ convert_long4((float) (scale / (1<<48)) * gen)" + ")") + else: raise TypeError("unsupported RNG data type '%s'" % dtype) diff --git a/test/test_array.py b/test/test_array.py index be926d7e7b5f3e32a0fdc3e1772a41a38e4934b5..8c4ae9b7aacb7501ad5d871c53335c0645e2c1bd 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -442,7 +442,7 @@ def test_random(ctx_factory): ran = gen.normal(queue, (10007,), dtype, mu=4, sigma=3) - dtypes = [np.int32] + dtypes = [np.int32, np.int64] for dtype in dtypes: ran = gen.uniform(queue, (10000007,), dtype, a=200, b=300) assert (200 <= ran.get()).all()