From ecb964b25a5731ce9ae5a706e655d1cb2146d847 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 31 Dec 2020 12:42:17 -0600
Subject: [PATCH] Fix RNG event management

---
 pyopencl/clrandom.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py
index 2471d8d2..91d447e0 100644
--- a/pyopencl/clrandom.py
+++ b/pyopencl/clrandom.py
@@ -334,10 +334,13 @@ class RanluxGenerator:
             queue = ary.queue
 
         knl, size_multiplier = self.get_gen_kernel(ary.dtype, "uniform")
-        return knl(queue,
+        evt = knl(queue,
                 (self.num_work_items,), None,
                 self.state.data, ary.data, ary.size*size_multiplier,
-                b-a, a)
+                b-a, a, wait_for=ary.event)
+        ary.add_event(evt)
+        self.state.add_event(evt)
+        return ary
 
     def uniform(self, *args, **kwargs):
         """Make a new empty array, apply :meth:`fill_uniform` to it.
@@ -346,9 +349,7 @@ class RanluxGenerator:
         b = kwargs.pop("b", 1)
 
         result = cl_array.empty(*args, **kwargs)
-
-        result.add_event(
-                self.fill_uniform(result, queue=result.queue, a=a, b=b))
+        self.fill_uniform(result, queue=result.queue, a=a, b=b)
         return result
 
     def fill_normal(self, ary, mu=0, sigma=1, queue=None):
@@ -364,9 +365,13 @@ class RanluxGenerator:
             queue = ary.queue
 
         knl, size_multiplier = self.get_gen_kernel(ary.dtype, "normal")
-        return knl(queue,
+        evt = knl(queue,
                 (self.num_work_items,), self.wg_size,
-                self.state.data, ary.data, ary.size*size_multiplier, sigma, mu)
+                self.state.data, ary.data, ary.size*size_multiplier, sigma, mu,
+                wait_for=ary.events)
+        ary.add_event(evt)
+        self.state.add_event(evt)
+        return evt
 
     def normal(self, *args, **kwargs):
         """Make a new empty array, apply :meth:`fill_normal` to it.
@@ -375,9 +380,7 @@ class RanluxGenerator:
         sigma = kwargs.pop("sigma", 1)
 
         result = cl_array.empty(*args, **kwargs)
-
-        result.add_event(
-                self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma))
+        self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma)
         return result
 
     @memoize_method
@@ -663,6 +666,7 @@ class Random123GeneratorBase:
         gsize, lsize = splay(queue, ary.size)
 
         evt = knl(queue, gsize, lsize, *args)
+        ary.add_event(evt)
 
         self.counter[0] += n * counter_multiplier
         c1_incr, self.counter[0] = divmod(self.counter[0], self.counter_max)
@@ -684,9 +688,7 @@ class Random123GeneratorBase:
         b = kwargs.pop("b", 1)
 
         result = cl_array.empty(*args, **kwargs)
-
-        result.add_event(
-                self.fill_uniform(result, queue=result.queue, a=a, b=b))
+        self.fill_uniform(result, queue=result.queue, a=a, b=b)
         return result
 
     def fill_normal(self, ary, mu=0, sigma=1, queue=None):
@@ -703,9 +705,7 @@ class Random123GeneratorBase:
         sigma = kwargs.pop("sigma", 1)
 
         result = cl_array.empty(*args, **kwargs)
-
-        result.add_event(
-                self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma))
+        self.fill_normal(result, queue=result.queue, mu=mu, sigma=sigma)
         return result
 
 
-- 
GitLab