From 3b7109a69657619c147c18891b566e3a0881562f Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Tue, 16 Jan 2018 17:35:42 -0600
Subject: [PATCH] clrandom: Handle underaligned buffers (closes #6).

Also tests that the rng works on all supported data types.
---
 pyopencl/clrandom.py  | 14 +++++--
 test/test_clrandom.py | 87 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 4 deletions(-)
 create mode 100644 test/test_clrandom.py

diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py
index 2ac54c36..ba0d61b3 100644
--- a/pyopencl/clrandom.py
+++ b/pyopencl/clrandom.py
@@ -286,8 +286,11 @@ class RanluxGenerator(object):
               unsigned long idx = get_global_id(0)*4;
               while (idx + 4 < out_size)
               {
-                  *(global output_vec_t *) (output + idx) =
-                      GET_RANDOM_NUM(RANLUX_FUNC(&ranluxclstate));
+                  output_vec_t ran = GET_RANDOM_NUM(RANLUX_FUNC(&ranluxclstate));
+                  output[idx] = ran.x;
+                  output[idx + 1] = ran.y;
+                  output[idx + 2] = ran.z;
+                  output[idx + 3] = ran.w;
                   idx += 4*NUM_WORKITEMS;
               }
 
@@ -594,8 +597,11 @@ class Random123GeneratorBase(object):
                 unsigned long idx = get_global_id(0)*4;
                 while (idx + 4 < out_size)
                 {
-                    *(global output_vec_t *) (output + idx) =
-                        GET_RANDOM_NUM(gen_bits(&k, &c));
+                    output_vec_t ran = GET_RANDOM_NUM(gen_bits(&k, &c));
+                    output[idx] = ran.x;
+                    output[idx + 1] = ran.y;
+                    output[idx + 2] = ran.z;
+                    output[idx + 3] = ran.w;
                     idx += 4*get_global_size(0);
                 }
 
diff --git a/test/test_clrandom.py b/test/test_clrandom.py
new file mode 100644
index 00000000..3187fbf5
--- /dev/null
+++ b/test/test_clrandom.py
@@ -0,0 +1,87 @@
+from __future__ import division, print_function, absolute_import
+
+__copyright__ = "Copyright (C) 2018 Matt Wala"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import numpy as np
+import pytest
+
+import pyopencl as cl
+import pyopencl.cltypes as cltypes
+import pyopencl.clrandom as clrandom
+from pyopencl.tools import (  # noqa
+        pytest_generate_tests_for_pyopencl
+        as pytest_generate_tests)
+
+try:
+    import faulthandler
+except ImportError:
+    pass
+else:
+    faulthandler.enable()
+
+
+class RanluxGeneratorShim(object):
+
+    def __init__(self, cl_ctx):
+        self.queue = cl.CommandQueue(cl_ctx)
+        self.gen = clrandom.RanluxGenerator(self.queue)
+
+    def uniform(self, *args, **kwargs):
+        return self.gen.uniform(*args, **kwargs)
+
+    def normal(self, *args, **kwargs):
+        return self.gen.normal(*args, **kwargs)
+
+
+@pytest.mark.parametrize("rng_class", [
+    RanluxGeneratorShim,
+    clrandom.PhiloxGenerator,
+    clrandom.ThreefryGenerator])
+@pytest.mark.parametrize("dtype", [
+    np.int32,
+    np.int64,
+    np.float32,
+    np.float64,
+    cltypes.float2,
+    cltypes.float3,
+    cltypes.float4])
+def test_clrandom_dtypes(ctx_factory, rng_class, dtype):
+    cl_ctx = ctx_factory()
+    rng = rng_class(cl_ctx)
+
+    size = 10
+
+    with cl.CommandQueue(cl_ctx) as queue:
+        rng.uniform(queue, size, dtype)
+        
+        if dtype not in (np.int32, np.int64):
+            rng.normal(queue, size, dtype)
+
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        import py.test
+        py.test.cmdline.main([__file__])
-- 
GitLab