From 917c656d729aca4305d99dd4b72eac0943de5752 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sat, 25 Sep 2021 16:44:20 -0500
Subject: [PATCH] add flatten to numpy for an entire container

---
 arraycontext/__init__.py            |  4 +-
 arraycontext/container/traversal.py | 67 +++++++++++++++++++++++++++++
 test/test_arraycontext.py           | 36 ++++++++++++++--
 3 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 6120302..f1f9762 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -58,7 +58,8 @@ from .container.traversal import (
         rec_map_reduce_array_container,
         rec_multimap_reduce_array_container,
         thaw, freeze,
-        from_numpy, to_numpy)
+        from_numpy, to_numpy,
+        flatten_to_numpy, unflatten_from_numpy)
 
 from .impl.pyopencl import PyOpenCLArrayContext
 from .impl.pytato import PytatoPyOpenCLArrayContext
@@ -93,6 +94,7 @@ __all__ = (
         "rec_map_reduce_array_container", "rec_multimap_reduce_array_container",
         "thaw", "freeze",
         "from_numpy", "to_numpy",
+        "flatten_to_numpy", "unflatten_from_numpy",
 
         "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
 
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index ea5fce9..acf7ff2 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -27,6 +27,8 @@ Numpy conversion
 ~~~~~~~~~~~~~~~~
 .. autofunction:: from_numpy
 .. autofunction:: to_numpy
+.. autofunction:: flatten_to_numpy
+.. autofunction:: unflatten_from_numpy
 """
 
 __copyright__ = """
@@ -520,6 +522,71 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any:
     """
     return rec_map_array_container(actx.to_numpy, ary)
 
+
+def flatten_to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> np.ndarray:
+    """Convert all arrays in the :class:`~arraycontext.ArrayContainer`
+    to host :mod:`numpy` arrays, flatten them using :func:`~numpy.ravel`
+    and concatenate them into a single :class:`~numpy.ndarray`.
+
+    The order in which the individual leaf arrays appear in the final array is
+    dependent on the order given by :func:`~arraycontext.serialize_container`.
+    """
+    def _flatten_to_numpy(subary: ArrayOrContainerT) -> None:
+        try:
+            iterable = serialize_container(subary)
+        except TypeError:
+            result.append(actx.to_numpy(subary).ravel())
+        else:
+            for _, isubary in iterable:
+                _flatten_to_numpy(isubary)
+
+    result: List[np.ndarray] = []
+    _flatten_to_numpy(ary)
+
+    return np.concatenate(result)
+
+
+def unflatten_from_numpy(
+        template: ArrayOrContainerT, ary: np.ndarray,
+        actx: ArrayContext) -> ArrayOrContainerT:
+    """Unflatten an :class:`~numpy.ndarray` produced by :func:`flatten_to_numpy`
+    back into an :class:`~arraycontext.ArrayContainer`.
+
+    The order and sizes of each slice into *ary* are determined by the
+    array container *template*.
+    """
+    # NOTE: https://github.com/python/mypy/issues/7057
+    offset = 0
+
+    def _unflatten_from_numpy(subary: ArrayOrContainerT) -> ArrayOrContainerT:
+        nonlocal offset
+
+        try:
+            iterable = serialize_container(subary)
+        except TypeError:
+            # NOTE: the max is needed to handle device scalars with size == 0
+            offset += max(1, subary.size)
+            if offset > ary.size:
+                raise ValueError("'template' and 'ary' sizes do not match")
+
+            # FIXME: subary can be F-contiguous and ary will always be C-contiguous
+            return actx.from_numpy(
+                    ary[offset - subary.size:offset]
+                    .astype(subary.dtype, copy=False)
+                    .reshape(subary.shape)
+                    )
+        else:
+            return deserialize_container(subary, [
+                (key, _unflatten_from_numpy(isubary)) for key, isubary in iterable
+                ])
+
+    if ary.ndim != 1:
+        raise ValueError(
+                "only one dimensional arrays can be unflattened: "
+                f"'ary' has shape {ary.shape}")
+
+    return _unflatten_from_numpy(template)
+
 # }}}
 
 # vim: foldmethod=marker
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 8800098..dfa238d 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -882,13 +882,16 @@ def test_container_norm(actx_factory, ord):
 def test_numpy_conversion(actx_factory):
     actx = actx_factory()
 
+    nelements = 42
     ac = MyContainer(
             name="test_numpy_conversion",
-            mass=np.random.rand(42),
-            momentum=make_obj_array([np.random.rand(42) for _ in range(3)]),
-            enthalpy=np.random.rand(42),
+            mass=np.random.rand(nelements, nelements),
+            momentum=make_obj_array([np.random.rand(nelements) for _ in range(3)]),
+            enthalpy=np.array(np.random.rand()),
             )
 
+    # {{{ to/from_numpy
+
     from arraycontext import from_numpy, to_numpy
     ac_actx = from_numpy(ac, actx)
     ac_roundtrip = to_numpy(ac_actx, actx)
@@ -907,6 +910,33 @@ def test_numpy_conversion(actx_factory):
     with pytest.raises(ValueError):
         to_numpy(ac, actx)
 
+    # }}}
+
+    # {{{ un/flatten
+
+    from arraycontext import flatten_to_numpy, unflatten_from_numpy
+    ac_flat = flatten_to_numpy(ac_actx, actx)
+    assert ac_flat.shape == (nelements**2 + 3 * nelements + 1,)
+
+    ac_roundtrip = unflatten_from_numpy(ac_actx, ac_flat, actx)
+    for name in ("mass", "momentum", "enthalpy"):
+        field = getattr(ac_actx, name)
+        field_roundtrip = getattr(ac_roundtrip, name)
+
+        assert field.dtype == field_roundtrip.dtype
+        assert field.shape == field_roundtrip.shape
+        assert np.linalg.norm(
+                np.linalg.norm(to_numpy(field - field_roundtrip, actx))
+                ) < 1.0e-15
+
+    with pytest.raises(ValueError):
+        unflatten_from_numpy(ac_actx, ac_flat[:-12], actx)
+
+    with pytest.raises(ValueError):
+        unflatten_from_numpy(ac_actx, ac_flat.reshape(2, -1), actx)
+
+    # }}}
+
 # }}}
 
 
-- 
GitLab