diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 57e82ff02f4cf4bbdc3a921bff48928acffb23e3..85c905e3fe1898b9a81fa8c41cdac8d0a2bfc1b3 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -866,9 +866,19 @@ class ArrayBase(Record): if not sep_shape: return None + def unwrap_1d_indices(idx): + # This allows these indices to work on Python sequences, too, not + # just numpy arrays. + + if len(idx) == 1: + return idx[0] + else: + return idx + from pytools import indices_in_shape return [ - (i, self.name + "".join("_s%d" % sub_i for sub_i in i)) + (unwrap_1d_indices(i), + self.name + "".join("_s%d" % sub_i for sub_i in i)) for i in indices_in_shape(sep_shape)] # }}}