diff --git a/pyopencl/array.py b/pyopencl/array.py index 24c6ded6c4a9c3b5943f81a696aa0b3005e5efd0..3805d2d8e82b814e0c16b7258d1b2015b1753be7 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -305,6 +305,7 @@ class Array(object): .. attribute :: T .. automethod :: set .. automethod :: get + .. automethod :: get_async .. automethod :: copy .. automethod :: __str__ @@ -627,23 +628,7 @@ class Array(object): is_blocking=not async_) self.add_event(event1) - def get(self, queue=None, ary=None, async_=None, **kwargs): - """Transfer the contents of *self* into *ary* or a newly allocated - :mod:`numpy.ndarray`. If *ary* is given, it must have the same - shape and dtype. - - .. versionchanged:: 2015.2 - - *ary* with different shape was deprecated. - - .. versionchanged:: 2017.2.1 - - Python 3.7 makes ``async`` a reserved keyword. On older Pythons, - we will continue to accept *async* as a parameter, however this - should be considered deprecated. *async_* is the new, official - spelling. - """ - + def _get(self, queue=None, ary=None, async_=None, **kwargs): # {{{ handle 'async' deprecation async_arg = kwargs.pop("async", None) @@ -688,12 +673,63 @@ class Array(object): "to associate one.") if self.size: - cl.enqueue_copy(queue, ary, self.base_data, + event1 = cl.enqueue_copy(queue, ary, self.base_data, device_offset=self.offset, wait_for=self.events, is_blocking=not async_) + self.add_event(event1) + else: + event1 = None + + return ary, event1 + + def get(self, queue=None, ary=None, async_=None, **kwargs): + """Transfer the contents of *self* into *ary* or a newly allocated + :mod:`numpy.ndarray`. If *ary* is given, it must have the same + shape and dtype. + + .. versionchanged:: 2019.1.2 + + Calling with `async_=True` was deprecated and replaced by + :meth:`get_async`. + The event returned by :meth:`pyopencl.enqueue_copy` is now stored into + :attr:`events` to ensure data is not modified before the copy is + complete. + + .. versionchanged:: 2015.2 + + *ary* with different shape was deprecated. + + .. versionchanged:: 2017.2.1 + + Python 3.7 makes ``async`` a reserved keyword. On older Pythons, + we will continue to accept *async* as a parameter, however this + should be considered deprecated. *async_* is the new, official + spelling. + """ + + if async_: + from warnings import warn + warn("calling pyopencl.Array.get with `async_=True` is deprecated. " + "Please use pyopencl.Array.get_async for asynchronous " + "device-to-host transfers", + DeprecationWarning, 2) + + ary, event1 = self._get(queue=queue, ary=ary, async_=async_, **kwargs) return ary + def get_async(self, queue=None, ary=None, **kwargs): + """ + Asynchronous version of :meth:`get` which returns a tuple ``(ary, event)`` + containing the host array `ary` + and the :class:`pyopencl.NannyEvent` `event` returned by + :meth:`pyopencl.enqueue_copy`. + + .. versionadded:: 2019.1.2 + """ + + return self._get(queue=queue, ary=ary, async_=True, **kwargs) + def copy(self, queue=_copy_queue): """ :arg queue: The :class:`CommandQueue` for the returned array. diff --git a/test/test_array.py b/test/test_array.py index 02e43e2481f4b3bdcdbc25f4c682dbd0043052b0..e9fb2ddd1d4ae2aaf16f18a2696666b607970056 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1217,6 +1217,32 @@ def test_multi_put(ctx_factory): assert np.all(np.all(out_compare[i] == out_arrays[i].get()) for i in range(9)) +def test_get_async(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + a = np.random.rand(10**6).astype(np.dtype('float32')) + a_gpu = cl_array.to_device(queue, a) + b = a + a**5 + 1 + b_gpu = a_gpu + a_gpu**5 + 1 + + # deprecated, but still test + b1 = b_gpu.get(async_=True) # testing that this waits for events + b_gpu.finish() + assert np.abs(b1 - b).mean() < 1e-5 + + b1, evt = b_gpu.get_async() # testing that this waits for events + evt.wait() + assert np.abs(b1 - b).mean() < 1e-5 + + wait_event = cl.UserEvent(context) + b_gpu.add_event(wait_event) + b, evt = b_gpu.get_async() # testing that this doesn't hang + wait_event.set_status(cl.command_execution_status.COMPLETE) + evt.wait() + assert np.abs(b1 - b).mean() < 1e-5 + + def test_outoforderqueue_get(ctx_factory): context = ctx_factory() try: