diff --git a/test/test_array.py b/test/test_array.py index b2a5f1eaf45e7a78998721171e2003ec73cbee85..a63e5f3ef78ab0049874213cec9b734fdec62703 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1519,25 +1519,29 @@ def test_str_without_queue(ctx_factory): print(repr(a)) +@pytest.mark.parametrize("order", ("F", "C")) @pytest.mark.parametrize("input_dims", (1, 2, 3)) -def test_stack(ctx_factory, input_dims): +def test_stack(ctx_factory, input_dims, order): # Replicates pytato/test/test_codegen.py::test_stack import pyopencl.array as cla cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) shape = (2, 2, 2)[:input_dims] + axis = -1 if order == "F" else 0 from numpy.random import default_rng rng = default_rng() x_in = rng.random(size=shape) y_in = rng.random(size=shape) + x_in = x_in if order == "C" else np.asfortranarray(x_in) + y_in = y_in if order == "C" else np.asfortranarray(y_in) x = cla.to_device(queue, x_in) y = cla.to_device(queue, y_in) - np.testing.assert_allclose(cla.stack((x, y)).get(), - np.stack((x_in, y_in))) + np.testing.assert_allclose(cla.stack((x, y), axis=axis).get(), + np.stack((x_in, y_in), axis=axis)) if __name__ == "__main__":