diff --git a/pyopencl/array.py b/pyopencl/array.py index 08f7cc8b969c60b22df0565806d3db30f9126b4b..326eea7c0eb90656884fda1f28afba0e364761b7 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -631,14 +631,29 @@ class Array(object): return ary - def copy(self, queue=None): - """.. versionadded:: 2013.1""" + def copy(self, queue=_copy_queue): + """ + :arg queue: The :class:`CommandQueue` for the returned array. - queue = queue or self.queue - result = self._new_like_me() + .. versionchanged:: 2017.1.2 + Updates the queue of the returned array. + + .. versionadded:: 2013.1 + """ + + if queue is _copy_queue: + queue = self.queue + + result = self._new_like_me(queue=queue) + + # result.queue won't be the same as queue if queue is None. + # We force them to be the same here. + if result.queue is not queue: + result = result.with_queue(queue) if self.nbytes: - cl.enqueue_copy(queue, result.base_data, self.base_data, + cl.enqueue_copy(queue or self.queue, + result.base_data, self.base_data, src_offset=self.offset, byte_count=self.nbytes) return result @@ -1782,6 +1797,8 @@ class Array(object): # }}} +# {{{ creation helpers + def as_strided(ary, shape=None, strides=None): """Make an :class:`Array` from the given array with the given shape and strides. @@ -1797,10 +1814,6 @@ def as_strided(ary, shape=None, strides=None): return Array(ary.queue, shape, ary.dtype, allocator=ary.allocator, data=ary.data, strides=strides) -# }}} - - -# {{{ creation helpers class _same_as_transfer(object): # noqa pass diff --git a/pyopencl/version.py b/pyopencl/version.py index cd1ab88fa776cf4debbe070f4f8b0b9c372168bf..751a8cfcb1695d46d5374083d140b09f9ee9d8e0 100644 --- a/pyopencl/version.py +++ b/pyopencl/version.py @@ -1,3 +1,3 @@ -VERSION = (2017, 1, 1) +VERSION = (2017, 1, 2) VERSION_STATUS = "" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS diff --git a/test/test_array.py b/test/test_array.py index 4c41890303278dc1490528e74f1a0e147d1bbbd5..115488c3e3f9c8d5689a921bd9bc790abe0ae968 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -762,6 +762,32 @@ def test_diff(ctx_factory): (cl.array.diff(a_dev).get() - np.diff(a))) assert err < 1e-4 + +def test_copy(ctx_factory): + context = ctx_factory() + queue1 = cl.CommandQueue(context) + queue2 = cl.CommandQueue(context) + + # Test copy + + arr = cl.array.zeros(queue1, 100, np.int32) + arr_copy = arr.copy() + + assert (arr == arr_copy).all().get() + assert arr.data != arr_copy.data + assert arr_copy.queue is queue1 + + # Test queue association + + arr_copy = arr.copy(queue=queue2) + assert arr_copy.queue is queue2 + + arr_copy = arr.copy(queue=None) + assert arr_copy.queue is None + + arr_copy = arr.with_queue(None).copy(queue=queue1) + assert arr_copy.queue is queue1 + # }}}