diff --git a/doc/array.rst b/doc/array.rst
index f44ce94545469656e92153f0c3aeda86292a6097..dc9f3974729707d2768c2f2d49f6a86c416e36e3 100644
--- a/doc/array.rst
+++ b/doc/array.rst
@@ -149,6 +149,7 @@ Constructing :class:`Array` Instances
 .. autofunction:: arange
 .. autofunction:: take
 .. autofunction:: concatenate
+.. autofunction:: stack
 
 Manipulating :class:`Array` instances
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/pyopencl/array.py b/pyopencl/array.py
index 874ae92c47adf72ebcf678705745feebd8700ce6..97dc28dc405fd23941c4a80da6111806205639fc 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -2238,7 +2238,7 @@ def arange(queue, *args, **kwargs):
 # }}}
 
 
-# {{{ take/put/concatenate/diff
+# {{{ take/put/concatenate/diff/(h?stack)
 
 @elwise_kernel_runner
 def _take(result, ary, indices):
@@ -2579,6 +2579,67 @@ def hstack(arrays, queue=None):
 
     return result
 
+
+def stack(arrays, axis=0, queue=None):
+    """
+    Join a sequence of arrays along a new axis.
+
+    :arg arrays: A sequence of :class:`Array`, all of the same shape.
+    :arg axis: Index of the dimension of the new axis in the result array.
+        Negative values count from the end as in :func:`numpy.stack`, so
+        -1 makes the new axis the last dimension.
+    :arg queue: The command queue passed on when allocating the result.
+        If *None*, the queue of the first array that has one is used.
+
+    :returns: :class:`Array`
+    """
+    if not arrays:
+        raise ValueError("need at least one array to stack")
+
+    input_shape = arrays[0].shape
+    input_ndim = arrays[0].ndim
+
+    # The result has input_ndim + 1 dimensions; normalize negative axes
+    # against that, as numpy.stack does.
+    if axis < 0:
+        axis += input_ndim + 1
+
+    if queue is None:
+        for ary in arrays:
+            if ary.queue is not None:
+                queue = ary.queue
+                break
+
+    if not all(ary.shape == input_shape for ary in arrays[1:]):
+        raise ValueError("arrays must have the same shape")
+
+    if not (0 <= axis <= input_ndim):
+        raise ValueError("invalid axis")
+
+    if (axis == 0 and not all(ary.flags.c_contiguous
+                               for ary in arrays)):
+        # pyopencl.Array.__setitem__ does not support non-contiguous assignments
+        raise NotImplementedError("stacking non-C-contiguous arrays along "
+                                  "axis 0 is not supported")
+
+    if (axis == input_ndim and not all(ary.flags.f_contiguous
+                                        for ary in arrays)):
+        # pyopencl.Array.__setitem__ does not support non-contiguous assignments
+        raise NotImplementedError("stacking non-F-contiguous arrays along "
+                                  "the last axis is not supported")
+
+    result_shape = input_shape[:axis] + (len(arrays),) + input_shape[axis:]
+    result = empty(queue, result_shape, np.result_type(*(ary.dtype
+                                                         for ary in arrays)),
+                   # TODO: reconsider once arrays support non-contiguous
+                   # assignments
+                   order="C" if axis == 0 else "F")
+    for i, ary in enumerate(arrays):
+        idx = (slice(None),)*axis + (i,) + (slice(None),)*(input_ndim-axis)
+        result[idx] = ary
+
+    return result
+
 # }}}
 
 
diff --git a/test/test_array.py b/test/test_array.py
index eeb55ff72edf36fa76388c8d5eb3e3db1763ad94..a63e5f3ef78ab0049874213cec9b734fdec62703 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -1519,6 +1519,31 @@ def test_str_without_queue(ctx_factory):
     print(repr(a))
 
 
+@pytest.mark.parametrize("order", ("F", "C"))
+@pytest.mark.parametrize("input_dims", (1, 2, 3))
+def test_stack(ctx_factory, input_dims, order):
+    # Replicates pytato/test/test_codegen.py::test_stack
+    import pyopencl.array as cla
+    cl_ctx = ctx_factory()
+    queue = cl.CommandQueue(cl_ctx)
+
+    shape = (2, 2, 2)[:input_dims]
+    axis = -1 if order == "F" else 0
+
+    from numpy.random import default_rng
+    rng = default_rng()
+    x_in = rng.random(size=shape)
+    y_in = rng.random(size=shape)
+    x_in = x_in if order == "C" else np.asfortranarray(x_in)
+    y_in = y_in if order == "C" else np.asfortranarray(y_in)
+
+    x = cla.to_device(queue, x_in)
+    y = cla.to_device(queue, y_in)
+
+    np.testing.assert_allclose(cla.stack((x, y), axis=axis).get(),
+                               np.stack((x_in, y_in), axis=axis))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])