From 77e6ca0eaa4c191bfbd18e461dfaf2d05430f2b8 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 24 May 2021 16:05:54 -0500
Subject: [PATCH] add recursive to_numpy and from_numpy

---
 arraycontext/__init__.py            |  4 ++-
 arraycontext/container/traversal.py | 38 +++++++++++++++++++++++++++++
 test/test_arraycontext.py           | 35 +++++++++++++++++++++++++-
 3 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 83f1684..34df04b 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -47,7 +47,8 @@ from .container.traversal import (
         rec_multimap_array_container,
         mapped_over_array_containers,
         multimapped_over_array_containers,
-        thaw, freeze)
+        thaw, freeze,
+        from_numpy, to_numpy)
 
 from .impl.pyopencl import PyOpenCLArrayContext
 
@@ -74,6 +75,7 @@ __all__ = (
         "mapped_over_array_containers",
         "multimapped_over_array_containers",
         "thaw", "freeze",
+        "from_numpy", "to_numpy",
 
         "PyOpenCLArrayContext",
 
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index ab8a4ff..fce37a6 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -15,6 +15,11 @@ Freezing and thawing
 ~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: freeze
 .. autofunction:: thaw
+
+Numpy conversion
+~~~~~~~~~~~~~~~~
+.. autofunction:: from_numpy
+.. autofunction:: to_numpy
 """
 
 __copyright__ = """
@@ -43,6 +48,9 @@ THE SOFTWARE.
 
 from typing import Any, Callable
 from functools import update_wrapper, partial, singledispatch
+
+import numpy as np
+
 from arraycontext.container import (is_array_container,
         serialize_container, deserialize_container)
 
@@ -265,4 +273,34 @@ def thaw(ary, actx):
 
 # }}}
 
+
+# {{{ numpy conversion
+
+def from_numpy(ary, actx):
+    """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
+    to the base array type of :class:`~arraycontext.ArrayContext`.
+
+    The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
+    """
+    def _from_numpy(subary):
+        if isinstance(subary, np.ndarray) and subary.dtype != "O":
+            return actx.from_numpy(subary)
+        elif is_array_container(subary):
+            return map_array_container(_from_numpy, subary)
+        else:
+            raise TypeError(f"unrecognized array type: '{type(subary).__name__}'")
+
+    return _from_numpy(ary)
+
+
+def to_numpy(ary, actx):
+    """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to
+    :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*.
+
+    The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`.
+    """
+    return rec_map_array_container(actx.to_numpy, ary)
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index a8bdcb7..f1f78a1 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -204,7 +204,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args):
 # }}}
 
 
-# {{{ Array manipulations
+# {{{ array manipulations
 
 def test_actx_stack(actx_factory):
     actx = actx_factory()
@@ -589,6 +589,39 @@ def test_container_norm(actx_factory, ord):
 # }}}
 
 
+# {{{ test from_numpy and to_numpy
+
+def test_numpy_conversion(actx_factory):
+    actx = actx_factory()
+
+    ac = MyContainer(
+            name="test_numpy_conversion",
+            mass=np.random.rand(42),
+            momentum=make_obj_array([np.random.rand(42) for _ in range(3)]),
+            enthalpy=np.random.rand(42),
+            )
+
+    from arraycontext import from_numpy, to_numpy
+    ac_actx = from_numpy(ac, actx)
+    ac_roundtrip = to_numpy(ac_actx, actx)
+
+    assert np.allclose(ac.mass, ac_roundtrip.mass)
+    assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0])
+
+    from dataclasses import replace
+    ac_with_cl = replace(ac, enthalpy=ac_actx.mass)
+    with pytest.raises(TypeError):
+        from_numpy(ac_with_cl, actx)
+
+    with pytest.raises(TypeError):
+        from_numpy(ac_actx, actx)
+
+    with pytest.raises(ValueError):
+        to_numpy(ac, actx)
+
+# }}}
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab