From a0447cc4532d235c8a8e550ae893eb1ea7c154fb Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com>
Date: Mon, 14 Jun 2021 12:15:00 -0500
Subject: [PATCH] Implement PyOpenCLArrayContext.ravel (#34)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* implement PyOpenCLArrayContext.ravel

* fixup! implement PyOpenCLArrayContext.ravel

- order="K" was only correct for >0 stride values, instead renamed it to
  order="A" as per numpy.

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 arraycontext/impl/pyopencl.py | 25 ++++++++++++++++++++++++-
 test/test_arraycontext.py     | 11 +++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py
index 0cd3f64..3dd5f3f 100644
--- a/arraycontext/impl/pyopencl.py
+++ b/arraycontext/impl/pyopencl.py
@@ -38,7 +38,8 @@ from pytools.tag import Tag
 from arraycontext.metadata import FirstAxisIsElementsTag
 from arraycontext.fake_numpy import \
         BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
-from arraycontext.container.traversal import rec_multimap_array_container
+from arraycontext.container.traversal import (rec_multimap_array_container,
+                                              rec_map_array_container)
 from arraycontext.container import serialize_container, is_array_container
 from arraycontext.context import ArrayContext
 
@@ -141,6 +142,28 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
             self._array_context.allocator
         )
 
+    def ravel(self, a, order="C"):
+        def _rec_ravel(a):
+            if order in "FC":
+                return a.reshape(-1, order=order)
+            elif order == "A":
+                # TODO: upstream this to pyopencl.array
+                if a.flags.f_contiguous:
+                    return a.reshape(-1, order="F")
+                elif a.flags.c_contiguous:
+                    return a.reshape(-1, order="C")
+                else:
+                    raise ValueError("For `order='A'`, array should be either"
+                                     " F-contiguous or C-contiguous.")
+            elif order == "K":
+                raise NotImplementedError("PyOpenCLArrayContext.np.ravel not "
+                                          "implemented for 'order=K'")
+            else:
+                raise ValueError("`order` can be one of 'F', 'C', 'A' or 'K'. "
+                                 f"(got {order})")
+
+        return rec_map_array_container(_rec_ravel, a)
+
 # }}}
 
 
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 7483246..d8f9ccb 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -235,6 +235,17 @@ def test_actx_reshape(actx_factory):
                 actx, lambda _np, *_args: _np.reshape(*_args),
                 (np.random.randn(2, 3), new_shape))
 
+
+def test_actx_ravel(actx_factory):
+    from numpy.random import default_rng
+    actx = actx_factory()
+    rng = default_rng()
+    ndim = rng.integers(low=1, high=6)
+    shape = tuple(rng.integers(2, 7, ndim))
+
+    assert_close_to_numpy(actx, lambda _np, ary: _np.ravel(ary),
+                          (rng.random(shape),))
+
 # }}}
 
 
-- 
GitLab