From 5a0947508a65deb3623ff3863fb35f511975d643 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 18 Aug 2015 11:50:28 -0500 Subject: [PATCH] Implement Array.squeeze (based on work by Simon Perkins for PyCUDA) --- pyopencl/array.py | 15 +++++++++++++++ test/test_array.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/pyopencl/array.py b/pyopencl/array.py index a9259fc5..647bcf7a 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -371,6 +371,7 @@ class Array(object): .. automethod :: reshape .. automethod :: ravel .. automethod :: view + .. automethod :: squeeze .. automethod :: transpose .. attribute :: T .. automethod :: set @@ -1444,6 +1445,20 @@ class Array(object): shape=new_shape, dtype=dtype, strides=new_strides) + def squeeze(self): + """Returns a view of the array with dimensions of + length 1 removed. + + .. versionadded:: 2015.2 + """ + new_shape = tuple([dim for dim in self.shape if dim > 1]) + new_strides = tuple([self.strides[i] + for i, dim in enumerate(self.shape) if dim > 1]) + + return self._new_with_changes( + self.base_data, self.offset, + shape=new_shape, strides=new_strides) + def transpose(self, axes=None): """Permute the dimensions of an array. diff --git a/test/test_array.py b/test/test_array.py index f3b6a668..333d1337 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -873,6 +873,41 @@ def test_newaxis(ctx_factory): assert b_gpu.strides[i] == b.strides[i] +def test_squeeze(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + shape = (40, 2, 5, 100) + a_cpu = np.random.random(size=shape) + a_gpu = cl_array.to_device(queue, a_cpu) + + # Slice with length 1 on dimensions 0 and 1 + a_gpu_slice = a_gpu[0:1, 1:2, :, :] + assert a_gpu_slice.shape == (1, 1, shape[2], shape[3]) + assert a_gpu_slice.flags.c_contiguous is False + + # Squeeze it and obtain contiguity + a_gpu_squeezed_slice = a_gpu[0:1, 1:2, :, :].squeeze() + assert a_gpu_squeezed_slice.shape == (shape[2], shape[3]) + assert a_gpu_squeezed_slice.flags.c_contiguous is True + + # Check that we get the original values out + #assert np.all(a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel()) + + # Slice with length 1 on dimensions 2 + a_gpu_slice = a_gpu[:, :, 2:3, :] + assert a_gpu_slice.shape == (shape[0], shape[1], 1, shape[3]) + assert a_gpu_slice.flags.c_contiguous is False + + # Squeeze it, but no contiguity here + a_gpu_squeezed_slice = a_gpu[:, :, 2:3, :].squeeze() + assert a_gpu_squeezed_slice.shape == (shape[0], shape[1], shape[3]) + assert a_gpu_squeezed_slice.flags.c_contiguous is False + + # Check that we get the original values out + #assert np.all(a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel()) + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the # tests. -- GitLab