From 24f559f9f88a2731c8f4d09d46fcf842366e0bd9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 24 Apr 2013 16:22:25 -0400 Subject: [PATCH] Minor Array.reshape fixes. --- pyopencl/array.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index 65c8f008..1c8a0289 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -207,6 +207,14 @@ class DefaultAllocator(cl.tools.DeferredAllocator): DeprecationWarning, 2) cl.tools.DeferredAllocator.__init__(self, *args, **kwargs) +def _make_strides(itemsize, shape, order): + if order == "F": + return _f_contiguous_strides(itemsize, shape) + elif order == "C": + return _c_contiguous_strides(itemsize, shape) + else: + raise ValueError("invalid order: %s" % order) + # }}} # {{{ array class @@ -344,14 +352,7 @@ class Array(object): s = np.asscalar(s) if strides is None: - if order == "F": - strides = _f_contiguous_strides( - dtype.itemsize, shape) - elif order == "C": - strides = _c_contiguous_strides( - dtype.itemsize, shape) - else: - raise ValueError("invalid order: %s" % order) + strides = _make_strides(dtype.itemsize, shape, order) else: # FIXME: We should possibly perform some plausibility @@ -910,8 +911,14 @@ class Array(object): # {{{ views - def reshape(self, *shape): + def reshape(self, *shape, **kwargs): """Returns an array containing the same data with a new shape.""" + + order = kwargs.pop("order", "C") + if kwargs: + raise TypeError("unexpected keyword arguments: %s" + % kwargs.keys()) + # TODO: add more error-checking, perhaps if isinstance(shape[0], tuple) or isinstance(shape[0], list): shape = tuple(shape[0]) @@ -919,7 +926,8 @@ class Array(object): if size != self.size: raise ValueError("total size of new array must be unchanged") - return self._new_with_changes(data=self.data, shape=shape) + return self._new_with_changes(data=self.data, shape=shape, + strides=_make_strides(self.dtype.itemsize, shape, order)) def ravel(self): """Returns flattened array containing the same data.""" -- GitLab