From fa01a15c6d6fe01463ee44a932a559a2734f0db9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 21 Mar 2016 17:27:14 -0500 Subject: [PATCH] Support 64-bit integers in the RNG --- pyopencl/clrandom.py | 11 +++++++++++ test/test_array.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py index 5b73e6a0..db07d9b3 100644 --- a/pyopencl/clrandom.py +++ b/pyopencl/clrandom.py @@ -223,6 +223,17 @@ class RanluxGenerator(object): rng_expr = ("(shift " "+ convert_int4((float) scale * gen) " "+ convert_int4((float) (scale / (1<<24)) * gen))") + + elif dtype == np.int64: + assert distribution == "uniform" + bits = 64 + c_type = "long" + rng_expr = ("(shift " + "+ convert_long4((float) scale * gen) " + "+ convert_long4((float) (scale / (1<<24)) * gen)" + "+ convert_long4((float) (scale / (1<<48)) * gen)" + ")") + else: raise TypeError("unsupported RNG data type '%s'" % dtype) diff --git a/test/test_array.py b/test/test_array.py index be926d7e..8c4ae9b7 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -442,7 +442,7 @@ def test_random(ctx_factory): ran = gen.normal(queue, (10007,), dtype, mu=4, sigma=3) - dtypes = [np.int32] + dtypes = [np.int32, np.int64] for dtype in dtypes: ran = gen.uniform(queue, (10000007,), dtype, a=200, b=300) assert (200 <= ran.get()).all() -- GitLab