diff --git a/pyopencl/array.py b/pyopencl/array.py
index d777eda042543f3d3e140d11d1fe032460247566..2877ac56ea37bbb2a8fe583e46749bf58c34feef 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1183,18 +1183,27 @@ class Array(object):
         old_itemsize = self.dtype.itemsize
         itemsize = np.dtype(dtype).itemsize
 
-        if self.shape[-1] * old_itemsize % itemsize != 0:
+        from pytools import argmin2
+        min_stride_axis = argmin2(
+                (axis, abs(stride))
+                for axis, stride in enumerate(self.strides))
+
+        if self.shape[min_stride_axis] * old_itemsize % itemsize != 0:
             raise ValueError("new type not compatible with array")
 
-        shape = self.shape[:-1] + (self.shape[-1] * old_itemsize // itemsize,)
-        strides = tuple(
-                s * itemsize // old_itemsize
-                for s in self.strides)
+        new_shape = (
+                self.shape[:min_stride_axis]
+                + (self.shape[min_stride_axis] * old_itemsize // itemsize,)
+                + self.shape[min_stride_axis+1:])
+        new_strides = (
+                self.strides[:min_stride_axis]
+                + (self.strides[min_stride_axis] * itemsize // old_itemsize,)
+                + self.strides[min_stride_axis+1:])
 
         return self._new_with_changes(
                 self.base_data, self.offset,
-                shape=shape, dtype=dtype,
-                strides=strides)
+                shape=new_shape, dtype=dtype,
+                strides=new_strides)
 
     # }}}
 
diff --git a/test/test_array.py b/test/test_array.py
index 3125ac69a1d8d3cb8619f77518ba8df7c37140dd..b70deb509d9c8e50d4fe09b63d11a6748eef6e75 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -694,6 +694,24 @@ def test_map_to_host(ctx_factory):
     assert (a_host_saved == a_dev.get()).all()
 
 
+def test_view_and_strides(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    from pyopencl.clrandom import rand as clrand
+
+    X = clrand(queue, (5, 10), dtype=np.float32)
+    Y = X[:3, :5]
+    y = Y.view()
+
+    assert y.shape == Y.shape
+    assert y.strides == Y.strides
+
+    import pytest
+    with pytest.raises(AssertionError):
+        assert (y.get() == X.get()[:3, :5]).all()
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the
     # tests.