From 3b7109a69657619c147c18891b566e3a0881562f Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 16 Jan 2018 17:35:42 -0600 Subject: [PATCH 1/2] clrandom: Handle underaligned buffers (closes #6). Also tests that the rng works on all supported data types. --- pyopencl/clrandom.py | 14 +++++-- test/test_clrandom.py | 87 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 test/test_clrandom.py diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py index 2ac54c36..ba0d61b3 100644 --- a/pyopencl/clrandom.py +++ b/pyopencl/clrandom.py @@ -286,8 +286,11 @@ class RanluxGenerator(object): unsigned long idx = get_global_id(0)*4; while (idx + 4 < out_size) { - *(global output_vec_t *) (output + idx) = - GET_RANDOM_NUM(RANLUX_FUNC(&ranluxclstate)); + output_vec_t ran = GET_RANDOM_NUM(RANLUX_FUNC(&ranluxclstate)); + output[idx] = ran.x; + output[idx + 1] = ran.y; + output[idx + 2] = ran.z; + output[idx + 3] = ran.w; idx += 4*NUM_WORKITEMS; } @@ -594,8 +597,11 @@ class Random123GeneratorBase(object): unsigned long idx = get_global_id(0)*4; while (idx + 4 < out_size) { - *(global output_vec_t *) (output + idx) = - GET_RANDOM_NUM(gen_bits(&k, &c)); + output_vec_t ran = GET_RANDOM_NUM(gen_bits(&k, &c)); + output[idx] = ran.x; + output[idx + 1] = ran.y; + output[idx + 2] = ran.z; + output[idx + 3] = ran.w; idx += 4*get_global_size(0); } diff --git a/test/test_clrandom.py b/test/test_clrandom.py new file mode 100644 index 00000000..3187fbf5 --- /dev/null +++ b/test/test_clrandom.py @@ -0,0 +1,87 @@ +from __future__ import division, print_function, absolute_import + +__copyright__ = "Copyright (C) 2018 Matt Wala" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import numpy as np +import pytest + +import pyopencl as cl +import pyopencl.cltypes as cltypes +import pyopencl.clrandom as clrandom +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl + as pytest_generate_tests) + +try: + import faulthandler +except ImportError: + pass +else: + faulthandler.enable() + + +class RanluxGeneratorShim(object): + + def __init__(self, cl_ctx): + self.queue = cl.CommandQueue(cl_ctx) + self.gen = clrandom.RanluxGenerator(self.queue) + + def uniform(self, *args, **kwargs): + return self.gen.uniform(*args, **kwargs) + + def normal(self, *args, **kwargs): + return self.gen.normal(*args, **kwargs) + + +@pytest.mark.parametrize("rng_class", [ + RanluxGeneratorShim, + clrandom.PhiloxGenerator, + clrandom.ThreefryGenerator]) +@pytest.mark.parametrize("dtype", [ + np.int32, + np.int64, + np.float32, + np.float64, + cltypes.float2, + cltypes.float3, + cltypes.float4]) +def test_clrandom_dtypes(ctx_factory, rng_class, dtype): + cl_ctx = ctx_factory() + rng = rng_class(cl_ctx) + + size = 10 + + with cl.CommandQueue(cl_ctx) as queue: + rng.uniform(queue, size, dtype) + + if dtype not in (np.int32, np.int64): + rng.normal(queue, size, dtype) + + +if __name__ == "__main__": + import sys + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + import py.test + py.test.cmdline.main([__file__]) -- GitLab From 307bd61d1ff99ba9a02e8d61d811bec52c4d3dbb Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 16 Jan 2018 18:01:06 -0600 Subject: [PATCH 2/2] flake8 fix --- test/test_clrandom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_clrandom.py b/test/test_clrandom.py index 3187fbf5..2be2e46f 100644 --- a/test/test_clrandom.py +++ b/test/test_clrandom.py @@ -73,7 +73,7 @@ def test_clrandom_dtypes(ctx_factory, rng_class, dtype): with cl.CommandQueue(cl_ctx) as queue: rng.uniform(queue, size, dtype) - + if dtype not in (np.int32, np.int64): rng.normal(queue, size, dtype) -- GitLab