From ecb964b25a5731ce9ae5a706e655d1cb2146d847 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 31 Dec 2020 12:42:17 -0600 Subject: [PATCH] Fix RNG event management --- pyopencl/clrandom.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py index 2471d8d2..91d447e0 100644 --- a/pyopencl/clrandom.py +++ b/pyopencl/clrandom.py @@ -334,10 +334,13 @@ class RanluxGenerator: queue = ary.queue knl, size_multiplier = self.get_gen_kernel(ary.dtype, "uniform") - return knl(queue, + evt = knl(queue, (self.num_work_items,), None, self.state.data, ary.data, ary.size*size_multiplier, - b-a, a) + b-a, a, wait_for=ary.event) + ary.add_event(evt) + self.state.add_event(evt) + return ary def uniform(self, *args, **kwargs): """Make a new empty array, apply :meth:`fill_uniform` to it. @@ -346,9 +349,7 @@ class RanluxGenerator: b = kwargs.pop("b", 1) result = cl_array.empty(*args, **kwargs) - - result.add_event( - self.fill_uniform(result, queue=result.queue, a=a, b=b)) + self.fill_uniform(result, queue=result.queue, a=a, b=b) return result def fill_normal(self, ary, mu=0, sigma=1, queue=None): @@ -364,9 +365,13 @@ class RanluxGenerator: queue = ary.queue knl, size_multiplier = self.get_gen_kernel(ary.dtype, "normal") - return knl(queue, + evt = knl(queue, (self.num_work_items,), self.wg_size, - self.state.data, ary.data, ary.size*size_multiplier, sigma, mu) + self.state.data, ary.data, ary.size*size_multiplier, sigma, mu, + wait_for=ary.events) + ary.add_event(evt) + self.state.add_event(evt) + return evt def normal(self, *args, **kwargs): """Make a new empty array, apply :meth:`fill_normal` to it. @@ -375,9 +380,7 @@ class RanluxGenerator: sigma = kwargs.pop("sigma", 1) result = cl_array.empty(*args, **kwargs) - - result.add_event( - self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma)) + self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma) return result @memoize_method @@ -663,6 +666,7 @@ class Random123GeneratorBase: gsize, lsize = splay(queue, ary.size) evt = knl(queue, gsize, lsize, *args) + ary.add_event(evt) self.counter[0] += n * counter_multiplier c1_incr, self.counter[0] = divmod(self.counter[0], self.counter_max) @@ -684,9 +688,7 @@ class Random123GeneratorBase: b = kwargs.pop("b", 1) result = cl_array.empty(*args, **kwargs) - - result.add_event( - self.fill_uniform(result, queue=result.queue, a=a, b=b)) + self.fill_uniform(result, queue=result.queue, a=a, b=b) return result def fill_normal(self, ary, mu=0, sigma=1, queue=None): @@ -703,9 +705,7 @@ class Random123GeneratorBase: sigma = kwargs.pop("sigma", 1) result = cl_array.empty(*args, **kwargs) - - result.add_event( - self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma)) + self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma) return result -- GitLab