diff --git a/pyopencl/array.py b/pyopencl/array.py index aa06d065c0fbd5cea1842181237dd3d3dc2212f7..97dc28dc405fd23941c4a80da6111806205639fc 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -2581,12 +2581,21 @@ def hstack(arrays, queue=None): def stack(arrays, axis=0, queue=None): + """ + Join a sequence of arrays along a new axis. + + :arg arrays: A sequnce of :class:`Array`. + :arg axis: Index of the dimension of the new axis in the result array. + Can be -1, for the new axis to be last dimension. + + :returns: :class:`Array` + """ if not arrays: raise ValueError("need at least one array to stack") - if axis != 0: - # pyopencl.Array.__setitem__ does not support non-contiguous assignments - raise NotImplementedError("axis!=0 not implemented") + input_shape = arrays[0].shape + input_ndim = arrays[0].ndim + axis = input_ndim if axis == -1 else axis if queue is None: for ary in arrays: @@ -2594,17 +2603,28 @@ def stack(arrays, axis=0, queue=None): queue = ary.queue break - if not all(ary.shape == arrays[0].shape for ary in arrays[1:]): + if not all(ary.shape == input_shape for ary in arrays[1:]): raise ValueError("arrays must have the same shape") - if not (0 <= axis <= arrays[0].ndim): + if not (0 <= axis <= input_ndim): raise ValueError("invalid axis") - input_shape = arrays[0].shape - input_ndim = arrays[0].ndim + if (axis == 0 and not all(ary.flags.c_contiguous + for ary in arrays)): + # pyopencl.Array.__setitem__ does not support non-contiguous assignments + raise NotImplementedError + + if (axis == input_ndim and not all(ary.flags.f_contiguous + for ary in arrays)): + # pyopencl.Array.__setitem__ does not support non-contiguous assignments + raise NotImplementedError + result_shape = input_shape[:axis] + (len(arrays),) + input_shape[axis:] - result = empty(queue, result_shape, np.result_type(*(ary.dtype for ary in - arrays))) + result = empty(queue, result_shape, np.result_type(*(ary.dtype + for ary in arrays)), + # TODO: reconsider once arrays support non-contiguous + # assignments + order="C" if axis == 0 else "F") for i, ary in enumerate(arrays): idx = (slice(None),)*axis + (i,) + (slice(None),)*(input_ndim-axis) result[idx] = ary