From 3dcf4b5fbea2b48a4bc0a2b7b697a13b82378faf Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 29 May 2013 08:10:19 -0400 Subject: [PATCH] Add support for ellipsis to __getitem__. --- pyopencl/array.py | 80 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/pyopencl/array.py b/pyopencl/array.py index df67c17d..d8517fea 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1061,35 +1061,69 @@ class Array(object): if not isinstance(index, tuple): index = (index,) - if len(index) > len(self.shape): - raise IndexError("too many axes in index (have: %d, got: %d)" % (len(self.shape), len(index))) - new_shape = [] new_offset = self.offset new_strides = [] - for i, (subidx, shape_i, strides_i) in enumerate( - zip(index, self.shape, self.strides)): - if isinstance(subidx, slice): - start, stop, stride = subidx.indices(shape_i) - new_shape.append((stop-start)//stride) - new_strides.append(stride*strides_i) - new_offset += strides_i*start - elif isinstance(subidx, (int, np.integer)): - if subidx < 0: - subidx += shape_i - - if not (0 <= subidx < shape_i): - raise IndexError("subindex in axis %d out of range" % i) - - new_offset += strides_i*subidx + seen_ellipsis = False + + index_axis = 0 + array_axis = 0 + while index_axis < len(index): + index_entry = index[index_axis] + + if array_axis > len(self.shape): + raise IndexError("too many axes in index") + + if isinstance(index_entry, slice): + start, stop, idx_stride = index_entry.indices(self.shape[array_axis]) + + array_stride = self.strides[array_axis] + + new_shape.append((stop-start)//idx_stride) + new_strides.append(idx_stride*array_stride) + new_offset += array_stride*start + + index_axis += 1 + array_axis += 1 + + elif isinstance(index_entry, (int, np.integer)): + array_shape = self.shape[array_axis] + if index_entry < 0: + index_entry += array_shape + + if not (0 <= index_entry < array_shape): + raise IndexError("subindex in axis %d out of range" % index_axis) + + new_offset += self.strides[array_axis]*index_entry + + index_axis += 1 + array_axis += 1 + + elif index_entry is Ellipsis: + index_axis += 1 + + remaining_index_count = len(index) - index_axis + new_array_axis = len(self.shape) - remaining_index_count + if new_array_axis < array_axis: + raise IndexError("invalid use of ellipsis in index") + while array_axis < new_array_axis: + new_shape.append(self.shape[array_axis]) + new_strides.append(self.strides[array_axis]) + array_axis += 1 + + if seen_ellipsis: + raise IndexError("more than one ellipsis not allowed in index") + seen_ellipsis = True + else: - raise IndexError("invalid subindex in axis %d" % i) + raise IndexError("invalid subindex in axis %d" % index_axis) + + while array_axis < len(self.shape): + new_shape.append(self.shape[array_axis]) + new_strides.append(self.strides[array_axis]) - while i + 1 < len(self.shape): - i += 1 - new_shape.append(self.shape[i]) - new_strides.append(self.strides[i]) + array_axis += 1 return self._new_with_changes( data=self.base_data, -- GitLab