From 24f559f9f88a2731c8f4d09d46fcf842366e0bd9 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 24 Apr 2013 16:22:25 -0400
Subject: [PATCH] Minor Array.reshape fixes.

---
 pyopencl/array.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index 65c8f008..1c8a0289 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -207,6 +207,14 @@ class DefaultAllocator(cl.tools.DeferredAllocator):
                 DeprecationWarning, 2)
         cl.tools.DeferredAllocator.__init__(self, *args, **kwargs)
 
+def _make_strides(itemsize, shape, order):
+    if order == "F":
+        return _f_contiguous_strides(itemsize, shape)
+    elif order == "C":
+        return _c_contiguous_strides(itemsize, shape)
+    else:
+        raise ValueError("invalid order: %s" % order)
+
 # }}}
 
 # {{{ array class
@@ -344,14 +352,7 @@ class Array(object):
             s = np.asscalar(s)
 
         if strides is None:
-            if order == "F":
-                strides = _f_contiguous_strides(
-                        dtype.itemsize, shape)
-            elif order == "C":
-                strides = _c_contiguous_strides(
-                        dtype.itemsize, shape)
-            else:
-                raise ValueError("invalid order: %s" % order)
+            strides = _make_strides(dtype.itemsize, shape, order)
 
         else:
             # FIXME: We should possibly perform some plausibility
@@ -910,8 +911,14 @@ class Array(object):
 
     # {{{ views
 
-    def reshape(self, *shape):
+    def reshape(self, *shape, **kwargs):
         """Returns an array containing the same data with a new shape."""
+
+        order = kwargs.pop("order", "C")
+        if kwargs:
+            raise TypeError("unexpected keyword arguments: %s"
+                    % kwargs.keys())
+
         # TODO: add more error-checking, perhaps
         if isinstance(shape[0], tuple) or isinstance(shape[0], list):
             shape = tuple(shape[0])
@@ -919,7 +926,8 @@ class Array(object):
         if size != self.size:
             raise ValueError("total size of new array must be unchanged")
 
-        return self._new_with_changes(data=self.data, shape=shape)
+        return self._new_with_changes(data=self.data, shape=shape,
+                strides=_make_strides(self.dtype.itemsize, shape, order))
 
     def ravel(self):
         """Returns flattened array containing the same data."""
-- 
GitLab