From f3ddebeb0904de68faccead3a3fc468fbd486328 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 17 Aug 2013 19:36:26 -0500
Subject: [PATCH] Test, slightly modify strides treatment in Array.view()

---
 pyopencl/array.py  | 23 ++++++++++++++++-------
 test/test_array.py | 18 ++++++++++++++++++
 2 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index d777eda0..2877ac56 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1183,18 +1183,27 @@ class Array(object):
         old_itemsize = self.dtype.itemsize
         itemsize = np.dtype(dtype).itemsize
 
-        if self.shape[-1] * old_itemsize % itemsize != 0:
+        from pytools import argmin2
+        min_stride_axis = argmin2(
+                (axis, abs(stride))
+                for axis, stride in enumerate(self.strides))
+
+        if self.shape[min_stride_axis] * old_itemsize % itemsize != 0:
             raise ValueError("new type not compatible with array")
 
-        shape = self.shape[:-1] + (self.shape[-1] * old_itemsize // itemsize,)
-        strides = tuple(
-                s * itemsize // old_itemsize
-                for s in self.strides)
+        new_shape = (
+                self.shape[:min_stride_axis]
+                + (self.shape[min_stride_axis] * old_itemsize // itemsize,)
+                + self.shape[min_stride_axis+1:])
+        new_strides = (
+                self.strides[:min_stride_axis]
+                + (self.strides[min_stride_axis] * itemsize // old_itemsize,)
+                + self.strides[min_stride_axis+1:])
 
         return self._new_with_changes(
                 self.base_data, self.offset,
-                shape=shape, dtype=dtype,
-                strides=strides)
+                shape=new_shape, dtype=dtype,
+                strides=new_strides)
 
     # }}}
 
diff --git a/test/test_array.py b/test/test_array.py
index 3125ac69..b70deb50 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -694,6 +694,24 @@ def test_map_to_host(ctx_factory):
     assert (a_host_saved == a_dev.get()).all()
 
 
+def test_view_and_strides(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    from pyopencl.clrandom import rand as clrand
+
+    X = clrand(queue, (5, 10), dtype=np.float32)
+    Y = X[:3, :5]
+    y = Y.view()
+
+    assert y.shape == Y.shape
+    assert y.strides == Y.strides
+
+    import pytest
+    with pytest.raises(AssertionError):
+        assert (y.get() == X.get()[:3, :5]).all()
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the
     # tests.
-- 
GitLab