From f3ddebeb0904de68faccead3a3fc468fbd486328 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sat, 17 Aug 2013 19:36:26 -0500 Subject: [PATCH] Test, slightly modify strides treatment in Array.view() --- pyopencl/array.py | 23 ++++++++++++++++------- test/test_array.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index d777eda0..2877ac56 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1183,18 +1183,27 @@ class Array(object): old_itemsize = self.dtype.itemsize itemsize = np.dtype(dtype).itemsize - if self.shape[-1] * old_itemsize % itemsize != 0: + from pytools import argmin2 + min_stride_axis = argmin2( + (axis, abs(stride)) + for axis, stride in enumerate(self.strides)) + + if self.shape[min_stride_axis] * old_itemsize % itemsize != 0: raise ValueError("new type not compatible with array") - shape = self.shape[:-1] + (self.shape[-1] * old_itemsize // itemsize,) - strides = tuple( - s * itemsize // old_itemsize - for s in self.strides) + new_shape = ( + self.shape[:min_stride_axis] + + (self.shape[min_stride_axis] * old_itemsize // itemsize,) + + self.shape[min_stride_axis+1:]) + new_strides = ( + self.strides[:min_stride_axis] + + (self.strides[min_stride_axis] * itemsize // old_itemsize,) + + self.strides[min_stride_axis+1:]) return self._new_with_changes( self.base_data, self.offset, - shape=shape, dtype=dtype, - strides=strides) + shape=new_shape, dtype=dtype, + strides=new_strides) # }}} diff --git a/test/test_array.py b/test/test_array.py index 3125ac69..b70deb50 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -694,6 +694,24 @@ def test_map_to_host(ctx_factory): assert (a_host_saved == a_dev.get()).all() +def test_view_and_strides(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.clrandom import rand as clrand + + X = clrand(queue, (5, 10), dtype=np.float32) + Y = X[:3, :5] + y = Y.view() + + assert y.shape == Y.shape + assert y.strides == Y.strides + + import pytest + with pytest.raises(AssertionError): + assert (y.get() == X.get()[:3, :5]).all() + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the # tests. -- GitLab