From c56e8f5860a192024faaab1ef2f74411e2ef43e8 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 28 Sep 2021 09:49:37 -0500
Subject: [PATCH] hardcode flatten and unflatten in c order

---
 arraycontext/container/traversal.py      | 5 +++--
 arraycontext/impl/pyopencl/fake_numpy.py | 6 ++++--
 arraycontext/impl/pytato/fake_numpy.py   | 5 +++--
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index cf9640f..db3633d 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -526,7 +526,7 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
                         f"got {subary.dtype}, expected {common_dtype}")
 
             try:
-                flat_subary = actx.np.ravel(subary, order="A")
+                flat_subary = actx.np.ravel(subary, order="C")
             except ValueError as exc:
                 # NOTE: we can't do much if the array context fails to ravel,
                 # since it is the one responsible for the actual memory layout
@@ -580,7 +580,8 @@ def unflatten(
 
             flat_subary = ary[offset - template_subary.size:offset]
             try:
-                subary = actx.np.reshape(flat_subary, template_subary.shape)
+                subary = actx.np.reshape(flat_subary,
+                        template_subary.shape, order="C")
             except ValueError as exc:
                 # NOTE: we can't do much if the array context fails to reshape,
                 # since it is the one responsible for the actual memory layout
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index c5b57ef..c60f33d 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -172,8 +172,10 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
                     queue=self._array_context.queue),
                 *arrays)
 
-    def reshape(self, a, newshape):
-        return cl_array.reshape(a, newshape)
+    def reshape(self, a, newshape, order="C"):
+        return rec_map_array_container(
+                lambda ary: ary.reshape(newshape, order=order),
+                a)
 
     def concatenate(self, arrays, axis=0):
         return cl_array.concatenate(
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index 01efaec..62b5e20 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -64,8 +64,9 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
         return super().__getattr__(name)
 
-    def reshape(self, a, newshape):
-        return rec_multimap_array_container(pt.reshape, a, newshape)
+    def reshape(self, a, newshape, order="C"):
+        return rec_multimap_array_container(
+                partial(pt.reshape, order=order), a, newshape)
 
     def transpose(self, a, axes=None):
         return rec_multimap_array_container(pt.transpose, a, axes)
-- 
GitLab