diff --git a/doc/source/array.rst b/doc/source/array.rst index bd9ce90c43cb8b93af0dbff3f42c9f873616828d..663be68387588e6cbd0760f6dd7c54eda49c6321 100644 --- a/doc/source/array.rst +++ b/doc/source/array.rst @@ -90,6 +90,19 @@ The :class:`Array` Class Returns the size of the leading dimension of *self*. + .. method :: reshape(shape) + + Returns an array containing the same data with a new shape. + + .. method :: ravel() + + Returns flattened array containing the same data. + + .. metod :: view(dtype=None) + + Returns view of array with the same data. If *dtype* is different from + current dtype, the actual bytes of memory will be reinterpreted. + .. method :: set(ary, queue=None, async=False) Transfer the contents the :class:`numpy.ndarray` object *ary* diff --git a/pyopencl/array.py b/pyopencl/array.py index e7a441e8743cfda776149e347db29d48e722c004..0cd46d60be3910a3c1aac6509a2dec8198fc448c 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -274,9 +274,6 @@ class Array(object): self.data = self.allocator(self.size * self.dtype.itemsize) else: self.data = None - - if base is not None: - raise ValueError("If data is specified, base must be None.") else: self.data = data @@ -289,6 +286,28 @@ class Array(object): def flags(self): return _ArrayFlags(self) + def _new_with_changes(self, data, shape=None, dtype=None, strides=None, queue=None, + base=None): + if shape is None: + shape = self.shape + if dtype is None: + dtype = self.dtype + if strides is None: + strides = self.strides + if queue is None: + queue = self.queue + if base is None and data is self.data: + base = self + + if queue is not None: + return Array(queue, shape, dtype, + allocator=self.allocator, strides=strides, base=base) + elif self.allocator is not None: + return Array(self.allocator, shape, dtype, queue=queue, + strides=strides, base=base) + else: + return Array(self.context, shape, dtype, strides=strides, base=base) + #@memoize_method FIXME: reenable def get_sizes(self, queue): return splay(queue, self.mem_size) @@ -641,6 +660,37 @@ class Array(object): def __gt__(self, other): raise NotImplementedError + # {{{ views + + def reshape(self, *shape): + # TODO: add more error-checking, perhaps + if isinstance(shape[0], tuple) or isinstance(shape[0], list): + shape = tuple(shape[0]) + size = reduce(lambda x, y: x * y, shape, 1) + if size != self.size: + raise ValueError("total size of new array must be unchanged") + + return self._new_with_changes(data=self.data, shape=shape) + + def ravel(self): + return self.reshape(self.size) + + def view(self, dtype=None): + if dtype is None: + dtype = self.dtype + + old_itemsize = self.dtype.itemsize + itemsize = np.dtype(dtype).itemsize + + if self.shape[-1] * old_itemsize % itemsize != 0: + raise ValueError("new type not compatible with array") + + shape = self.shape[:-1] + (self.shape[-1] * old_itemsize // itemsize,) + + return self._new_with_changes(data=self.data, shape=shape, dtype=dtype) + + # }} + # }}} # {{{ creation helpers @@ -691,15 +741,7 @@ def zeros(*args, **kwargs): return _zeros(*args, **kwargs) def empty_like(ary): - if ary.queue is not None: - return Array(ary.queue, ary.shape, ary.dtype, - allocator=ary.allocator, strides=ary.strides) - elif ary.allocator is not None: - return Array(ary.allocator, ary.shape, ary.dtype, queue=ary.queue, - strides=ary.strides) - else: - return Array(ary.context, ary.shape, ary.dtype, - strides=ary.strides) + return ary._new_with_changes(data=None) def zeros_like(ary): result = empty_like(ary) diff --git a/test/test_array.py b/test/test_array.py index c0015d8779120a4bba35e098c94632b495bf4e04..8a6d4213b70fde452f3e4d2f65cd10c20040330d 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -636,6 +636,28 @@ def test_mem_pool_with_arrays(ctx_factory): assert b_dev.allocator is mem_pool assert result.allocator is mem_pool +@pytools.test.mark_test.opencl +def test_view(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + a = np.arange(128).reshape(8, 16).astype(np.float32) + a_dev = cl_array.to_device(queue, a) + + # same dtype + view = a_dev.view() + assert view.shape == a_dev.shape and view.dtype == a_dev.dtype + + # larger dtype + view = a_dev.view(np.complex64) + assert view.shape == (8, 8) and view.dtype == np.complex64 + + # smaller dtype + view = a_dev.view(np.int16) + assert view.shape == (8, 32) and view.dtype == np.int16 + + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the