Skip to content
Snippets Groups Projects
Commit 5d2d1ac8 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Revert clrandom to using None for workgroup size.

parent d7b2da05
Branches
Tags
No related merge requests found
......@@ -32,15 +32,19 @@ class RanluxGenerator(object):
self.max_work_items = max_work_items
src = """
%s
%(defines)s
#include <pyopencl-ranluxcl.cl>
kernel void init_ranlux(unsigned seeds, global ranluxcl_state_t *ranluxcltab)
{
ranluxcl_initialization(seeds, ranluxcltab);
if (get_global_id(0) < %(num_work_items)d)
ranluxcl_initialization(seeds, ranluxcltab);
}
""" % self.generate_settings_defines()
""" % {
"defines": self.generate_settings_defines(),
"num_work_items": num_work_items
}
prg = cl.Program(queue.context, src).build()
# {{{ compute work group size
......@@ -67,7 +71,7 @@ class RanluxGenerator(object):
self.state = cl_array.empty(queue, (num_work_items, 112), dtype=np.uint8)
self.state.fill(17)
prg.init_ranlux(queue, (num_work_items,), (self.wg_size,), np.uint32(seed),
prg.init_ranlux(queue, (num_work_items,), None, np.uint32(seed),
self.state.data)
def generate_settings_defines(self, include_double_pragma=True):
......@@ -132,6 +136,7 @@ class RanluxGenerator(object):
output_t scale,
output_t shift)
{
ranluxcl_state_t ranluxclstate;
ranluxcl_download_seed(&ranluxclstate, ranluxcltab);
......@@ -175,7 +180,7 @@ class RanluxGenerator(object):
queue = ary.queue
self.get_gen_kernel(ary.dtype, "")(queue,
(self.num_work_items,), (self.wg_size,),
(self.num_work_items,), None,
self.state.data, ary.data, ary.size,
b-a, a)
......@@ -193,7 +198,7 @@ class RanluxGenerator(object):
queue = ary.queue
self.get_gen_kernel(ary.dtype, "norm")(queue,
(self.num_work_items,), (self.wg_size,),
(self.num_work_items,), None,
self.state.data, ary.data, ary.size, sigma, mu)
def normal(self, *args, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment