diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index 0cd3f64406689d2be45468fd53b0df7a3a3ed04e..3dd5f3f4e79b5c75edd8714fae643a4a82ed2a78 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -38,7 +38,8 @@ from pytools.tag import Tag from arraycontext.metadata import FirstAxisIsElementsTag from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace -from arraycontext.container.traversal import rec_multimap_array_container +from arraycontext.container.traversal import (rec_multimap_array_container, + rec_map_array_container) from arraycontext.container import serialize_container, is_array_container from arraycontext.context import ArrayContext @@ -141,6 +142,28 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): self._array_context.allocator ) + def ravel(self, a, order="C"): + def _rec_ravel(a): + if order in "FC": + return a.reshape(-1, order=order) + elif order == "A": + # TODO: upstream this to pyopencl.array + if a.flags.f_contiguous: + return a.reshape(-1, order="F") + elif a.flags.c_contiguous: + return a.reshape(-1, order="C") + else: + raise ValueError("For `order='A'`, array should be either" + " F-contiguous or C-contiguous.") + elif order == "K": + raise NotImplementedError("PyOpenCLArrayContext.np.ravel not " + "implemented for 'order=K'") + else: + raise ValueError("`order` can be one of 'F', 'C', 'A' or 'K'. " + f"(got {order})") + + return rec_map_array_container(_rec_ravel, a) + # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 7483246ff29e1b67945007e65edcb6adff03f3b6..d8f9ccbe3e786bd7def309469588c063bea761d9 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -235,6 +235,17 @@ def test_actx_reshape(actx_factory): actx, lambda _np, *_args: _np.reshape(*_args), (np.random.randn(2, 3), new_shape)) + +def test_actx_ravel(actx_factory): + from numpy.random import default_rng + actx = actx_factory() + rng = default_rng() + ndim = rng.integers(low=1, high=6) + shape = tuple(rng.integers(2, 7, ndim)) + + assert_close_to_numpy(actx, lambda _np, ary: _np.ravel(ary), + (rng.random(shape),)) + # }}}