From 3dcf4b5fbea2b48a4bc0a2b7b697a13b82378faf Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 29 May 2013 08:10:19 -0400
Subject: [PATCH] Add support for ellipsis to __getitem__.

---
 pyopencl/array.py | 80 +++++++++++++++++++++++++++++++++--------------
 1 file changed, 57 insertions(+), 23 deletions(-)

diff --git a/pyopencl/array.py b/pyopencl/array.py
index df67c17d..d8517fea 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -1061,35 +1061,69 @@ class Array(object):
         if not isinstance(index, tuple):
             index = (index,)
 
-        if len(index) > len(self.shape):
-            raise IndexError("too many axes in index (have: %d, got: %d)" % (len(self.shape), len(index)))
-
         new_shape = []
         new_offset = self.offset
         new_strides = []
 
-        for i, (subidx, shape_i, strides_i) in enumerate(
-                zip(index, self.shape, self.strides)):
-            if isinstance(subidx, slice):
-                start, stop, stride = subidx.indices(shape_i)
-                new_shape.append((stop-start)//stride)
-                new_strides.append(stride*strides_i)
-                new_offset += strides_i*start
-            elif isinstance(subidx, (int, np.integer)):
-                if subidx < 0:
-                    subidx += shape_i
-
-                if not (0 <= subidx < shape_i):
-                    raise IndexError("subindex in axis %d out of range" % i)
-
-                new_offset += strides_i*subidx
+        seen_ellipsis = False
+
+        index_axis = 0
+        array_axis = 0
+        while index_axis < len(index):
+            index_entry = index[index_axis]
+
+            if array_axis > len(self.shape):
+                raise IndexError("too many axes in index")
+
+            if isinstance(index_entry, slice):
+                start, stop, idx_stride = index_entry.indices(self.shape[array_axis])
+
+                array_stride = self.strides[array_axis]
+
+                new_shape.append((stop-start)//idx_stride)
+                new_strides.append(idx_stride*array_stride)
+                new_offset += array_stride*start
+
+                index_axis += 1
+                array_axis += 1
+
+            elif isinstance(index_entry, (int, np.integer)):
+                array_shape = self.shape[array_axis]
+                if index_entry < 0:
+                    index_entry += array_shape
+
+                if not (0 <= index_entry < array_shape):
+                    raise IndexError("subindex in axis %d out of range" % index_axis)
+
+                new_offset += self.strides[array_axis]*index_entry
+
+                index_axis += 1
+                array_axis += 1
+
+            elif index_entry is Ellipsis:
+                index_axis += 1
+
+                remaining_index_count = len(index) - index_axis
+                new_array_axis = len(self.shape) - remaining_index_count
+                if new_array_axis < array_axis:
+                    raise IndexError("invalid use of ellipsis in index")
+                while array_axis < new_array_axis:
+                    new_shape.append(self.shape[array_axis])
+                    new_strides.append(self.strides[array_axis])
+                    array_axis += 1
+
+                if seen_ellipsis:
+                    raise IndexError("more than one ellipsis not allowed in index")
+                seen_ellipsis = True
+
             else:
-                raise IndexError("invalid subindex in axis %d" % i)
+                raise IndexError("invalid subindex in axis %d" % index_axis)
+
+        while array_axis < len(self.shape):
+            new_shape.append(self.shape[array_axis])
+            new_strides.append(self.strides[array_axis])
 
-        while i + 1 < len(self.shape):
-            i += 1
-            new_shape.append(self.shape[i])
-            new_strides.append(self.strides[i])
+            array_axis += 1
 
         return self._new_with_changes(
                 data=self.base_data,
-- 
GitLab