From 5a0947508a65deb3623ff3863fb35f511975d643 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 18 Aug 2015 11:50:28 -0500
Subject: [PATCH] Implement Array.squeeze (based on work by Simon Perkins for
 PyCUDA)

---
 pyopencl/array.py  | 15 +++++++++++++++
 test/test_array.py | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index a9259fc5..647bcf7a 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -371,6 +371,7 @@ class Array(object):
     .. automethod :: reshape
     .. automethod :: ravel
     .. automethod :: view
+    .. automethod :: squeeze
     .. automethod :: transpose
     .. attribute :: T
     .. automethod :: set
@@ -1444,6 +1445,20 @@ class Array(object):
                 shape=new_shape, dtype=dtype,
                 strides=new_strides)
 
+    def squeeze(self):
+        """Returns a view of the array with dimensions of
+        length 1 removed.
+
+        .. versionadded:: 2015.2
+        """
+        new_shape = tuple([dim for dim in self.shape if dim > 1])
+        new_strides = tuple([self.strides[i]
+            for i, dim in enumerate(self.shape) if dim > 1])
+
+        return self._new_with_changes(
+                self.base_data, self.offset,
+                shape=new_shape, strides=new_strides)
+
     def transpose(self, axes=None):
         """Permute the dimensions of an array.
 
diff --git a/test/test_array.py b/test/test_array.py
index f3b6a668..333d1337 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -873,6 +873,41 @@ def test_newaxis(ctx_factory):
             assert b_gpu.strides[i] == b.strides[i]
 
 
+def test_squeeze(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    shape = (40, 2, 5, 100)
+    a_cpu = np.random.random(size=shape)
+    a_gpu = cl_array.to_device(queue, a_cpu)
+
+    # Slice with length 1 on dimensions 0 and 1
+    a_gpu_slice = a_gpu[0:1, 1:2, :, :]
+    assert a_gpu_slice.shape == (1, 1, shape[2], shape[3])
+    assert a_gpu_slice.flags.c_contiguous is False
+
+    # Squeeze it and obtain contiguity
+    a_gpu_squeezed_slice = a_gpu[0:1, 1:2, :, :].squeeze()
+    assert a_gpu_squeezed_slice.shape == (shape[2], shape[3])
+    assert a_gpu_squeezed_slice.flags.c_contiguous is True
+
+    # Check that we get the original values out
+    #assert np.all(a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel())
+
+    # Slice with length 1 on dimensions 2
+    a_gpu_slice = a_gpu[:, :, 2:3, :]
+    assert a_gpu_slice.shape == (shape[0], shape[1], 1, shape[3])
+    assert a_gpu_slice.flags.c_contiguous is False
+
+    # Squeeze it, but no contiguity here
+    a_gpu_squeezed_slice = a_gpu[:, :, 2:3, :].squeeze()
+    assert a_gpu_squeezed_slice.shape == (shape[0], shape[1], shape[3])
+    assert a_gpu_squeezed_slice.flags.c_contiguous is False
+
+    # Check that we get the original values out
+    #assert np.all(a_gpu_slice.get().ravel() == a_gpu_squeezed_slice.get().ravel())
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the
     # tests.
-- 
GitLab