From fa01a15c6d6fe01463ee44a932a559a2734f0db9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 21 Mar 2016 17:27:14 -0500
Subject: [PATCH] Support 64-bit integers in the RNG

---
 pyopencl/clrandom.py | 11 +++++++++++
 test/test_array.py   |  2 +-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py
index 5b73e6a0..db07d9b3 100644
--- a/pyopencl/clrandom.py
+++ b/pyopencl/clrandom.py
@@ -223,6 +223,17 @@ class RanluxGenerator(object):
             rng_expr = ("(shift "
                     "+ convert_int4((float) scale * gen) "
                     "+ convert_int4((float) (scale / (1<<24)) * gen))")
+
+        elif dtype == np.int64:
+            assert distribution == "uniform"
+            bits = 64
+            c_type = "long"
+            rng_expr = ("(shift "
+                    "+ convert_long4((float) scale * gen) "
+                    "+ convert_long4((float) (scale / (1<<24)) * gen)"
+                    "+ convert_long4((float) (scale / (1<<48)) * gen)"
+                    ")")
+
         else:
             raise TypeError("unsupported RNG data type '%s'" % dtype)
 
diff --git a/test/test_array.py b/test/test_array.py
index be926d7e..8c4ae9b7 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -442,7 +442,7 @@ def test_random(ctx_factory):
 
             ran = gen.normal(queue, (10007,), dtype, mu=4, sigma=3)
 
-    dtypes = [np.int32]
+    dtypes = [np.int32, np.int64]
     for dtype in dtypes:
         ran = gen.uniform(queue, (10000007,), dtype, a=200, b=300)
         assert (200 <= ran.get()).all()
-- 
GitLab