From 51d03c14697b9d178d6246cc9233c591862363ba Mon Sep 17 00:00:00 2001
From: zachjweiner <zachjweiner@gmail.com>
Date: Tue, 27 Aug 2019 10:03:23 -0500
Subject: [PATCH] add async tests

---
 pyopencl/array.py  |  4 ++--
 test/test_array.py | 26 ++++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index c1b132cc..72d2a0cf 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -726,8 +726,8 @@ class Array(object):
 
     def get_async(self, queue=None, ary=None, **kwargs):
         """
-        Asynchronous version of :meth:`get`, following the same calling convention
-        while returning a tuple ``(ary, event)`` containing the host array `ary`
+        Asynchronous version of :meth:`get` which returns a tuple ``(ary, event)``
+        containing the host array `ary`
         and the :class:`pyopencl.NannyEvent` `event` returned by
         :meth:`pyopencl.enqueue_copy`.
         """
diff --git a/test/test_array.py b/test/test_array.py
index 02e43e24..cf63fc14 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -1217,6 +1217,32 @@ def test_multi_put(ctx_factory):
     assert np.all(np.all(out_compare[i] == out_arrays[i].get()) for i in range(9))
 
 
+def test_get_async(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    a = np.random.rand(10**6).astype(np.dtype('float32'))
+    a_gpu = cl_array.to_device(queue, a)
+    b = a + a**5 + 1
+    b_gpu = a_gpu + a_gpu**5 + 1
+
+    # deprecated, but still test
+    b1 = b_gpu.get(async_=True)  # testing that this waits for events
+    b_gpu.finish()
+    assert np.abs(b1 - b).mean() < 1e-5
+
+    b1 = b_gpu.get_async()  # testing that this waits for events
+    b_gpu.finish()
+    assert np.abs(b1 - b).mean() < 1e-5
+
+    wait_event = cl.UserEvent(context)
+    b_gpu.add_event(wait_event)
+    b = b_gpu.get_async()  # testing that this doesn't hang
+    wait_event.set_status(cl.command_execution_status.COMPLETE)
+    b_gpu.finish()
+    assert np.abs(b1 - b).mean() < 1e-5
+
+
 def test_outoforderqueue_get(ctx_factory):
     context = ctx_factory()
     try:
-- 
GitLab