diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py index ba0d61b32a727bb0e1406ad06015c0f763ad2a54..decab716aa5d5f192fb3e4820e47fd491f1f13ce 100644 --- a/pyopencl/clrandom.py +++ b/pyopencl/clrandom.py @@ -287,10 +287,7 @@ class RanluxGenerator(object): while (idx + 4 < out_size) { output_vec_t ran = GET_RANDOM_NUM(RANLUX_FUNC(&ranluxclstate)); - output[idx] = ran.x; - output[idx + 1] = ran.y; - output[idx + 2] = ran.z; - output[idx + 3] = ran.w; + vstore4(ran, 0, &output[idx]); idx += 4*NUM_WORKITEMS; } @@ -598,10 +595,7 @@ class Random123GeneratorBase(object): while (idx + 4 < out_size) { output_vec_t ran = GET_RANDOM_NUM(gen_bits(&k, &c)); - output[idx] = ran.x; - output[idx + 1] = ran.y; - output[idx + 2] = ran.z; - output[idx + 3] = ran.w; + vstore4(ran, 0, &output[idx]); idx += 4*get_global_size(0); } diff --git a/test/test_clrandom.py b/test/test_clrandom.py index 2be2e46f8356ec7a59658df1c1edcfec2ebc41b9..2846e24c97e1c4c0356420dfa35a1278e615a77f 100644 --- a/test/test_clrandom.py +++ b/test/test_clrandom.py @@ -40,21 +40,13 @@ else: faulthandler.enable() -class RanluxGeneratorShim(object): - - def __init__(self, cl_ctx): - self.queue = cl.CommandQueue(cl_ctx) - self.gen = clrandom.RanluxGenerator(self.queue) - - def uniform(self, *args, **kwargs): - return self.gen.uniform(*args, **kwargs) - - def normal(self, *args, **kwargs): - return self.gen.normal(*args, **kwargs) +def make_ranlux_generator(cl_ctx): + queue = cl.CommandQueue(cl_ctx) + return clrandom.RanluxGenerator(queue) @pytest.mark.parametrize("rng_class", [ - RanluxGeneratorShim, + make_ranlux_generator, clrandom.PhiloxGenerator, clrandom.ThreefryGenerator]) @pytest.mark.parametrize("dtype", [