diff --git a/pyopencl/array.py b/pyopencl/array.py index 874ae92c47adf72ebcf678705745feebd8700ce6..aa06d065c0fbd5cea1842181237dd3d3dc2212f7 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -2238,7 +2238,7 @@ def arange(queue, *args, **kwargs): # }}} -# {{{ take/put/concatenate/diff +# {{{ take/put/concatenate/diff/(h?stack) @elwise_kernel_runner def _take(result, ary, indices): @@ -2579,6 +2579,38 @@ def hstack(arrays, queue=None): return result + +def stack(arrays, axis=0, queue=None): + if not arrays: + raise ValueError("need at least one array to stack") + + if axis != 0: + # pyopencl.Array.__setitem__ does not support non-contiguous assignments + raise NotImplementedError("axis!=0 not implemented") + + if queue is None: + for ary in arrays: + if ary.queue is not None: + queue = ary.queue + break + + if not all(ary.shape == arrays[0].shape for ary in arrays[1:]): + raise ValueError("arrays must have the same shape") + + if not (0 <= axis <= arrays[0].ndim): + raise ValueError("invalid axis") + + input_shape = arrays[0].shape + input_ndim = arrays[0].ndim + result_shape = input_shape[:axis] + (len(arrays),) + input_shape[axis:] + result = empty(queue, result_shape, np.result_type(*(ary.dtype for ary in + arrays))) + for i, ary in enumerate(arrays): + idx = (slice(None),)*axis + (i,) + (slice(None),)*(input_ndim-axis) + result[idx] = ary + + return result + # }}}