diff --git a/test/test_array.py b/test/test_array.py index eeb55ff72edf36fa76388c8d5eb3e3db1763ad94..b2a5f1eaf45e7a78998721171e2003ec73cbee85 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1519,6 +1519,27 @@ def test_str_without_queue(ctx_factory): print(repr(a)) +@pytest.mark.parametrize("input_dims", (1, 2, 3)) +def test_stack(ctx_factory, input_dims): + # Replicates pytato/test/test_codegen.py::test_stack + import pyopencl.array as cla + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + shape = (2, 2, 2)[:input_dims] + + from numpy.random import default_rng + rng = default_rng() + x_in = rng.random(size=shape) + y_in = rng.random(size=shape) + + x = cla.to_device(queue, x_in) + y = cla.to_device(queue, y_in) + + np.testing.assert_allclose(cla.stack((x, y)).get(), + np.stack((x_in, y_in))) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])