From 77e6ca0eaa4c191bfbd18e461dfaf2d05430f2b8 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Mon, 24 May 2021 16:05:54 -0500 Subject: [PATCH] add recursive to_numpy and from_numpy --- arraycontext/__init__.py | 4 ++- arraycontext/container/traversal.py | 38 +++++++++++++++++++++++++++++ test/test_arraycontext.py | 35 +++++++++++++++++++++++++- 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 83f1684..34df04b 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -47,7 +47,8 @@ from .container.traversal import ( rec_multimap_array_container, mapped_over_array_containers, multimapped_over_array_containers, - thaw, freeze) + thaw, freeze, + from_numpy, to_numpy) from .impl.pyopencl import PyOpenCLArrayContext @@ -74,6 +75,7 @@ __all__ = ( "mapped_over_array_containers", "multimapped_over_array_containers", "thaw", "freeze", + "from_numpy", "to_numpy", "PyOpenCLArrayContext", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index ab8a4ff..fce37a6 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -15,6 +15,11 @@ Freezing and thawing ~~~~~~~~~~~~~~~~~~~~ .. autofunction:: freeze .. autofunction:: thaw + +Numpy conversion +~~~~~~~~~~~~~~~~ +.. autofunction:: from_numpy +.. autofunction:: to_numpy """ __copyright__ = """ @@ -43,6 +48,9 @@ THE SOFTWARE. from typing import Any, Callable from functools import update_wrapper, partial, singledispatch + +import numpy as np + from arraycontext.container import (is_array_container, serialize_container, deserialize_container) @@ -265,4 +273,34 @@ def thaw(ary, actx): # }}} + +# {{{ numpy conversion + +def from_numpy(ary, actx): + """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` + to the base array type of :class:`~arraycontext.ArrayContext`. + + The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. + """ + def _from_numpy(subary): + if isinstance(subary, np.ndarray) and subary.dtype != "O": + return actx.from_numpy(subary) + elif is_array_container(subary): + return map_array_container(_from_numpy, subary) + else: + raise TypeError(f"unrecognized array type: '{type(subary).__name__}'") + + return _from_numpy(ary) + + +def to_numpy(ary, actx): + """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to + :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. + + The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`. + """ + return rec_map_array_container(actx.to_numpy, ary) + +# }}} + # vim: foldmethod=marker diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index a8bdcb7..f1f78a1 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -204,7 +204,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args): # }}} -# {{{ Array manipulations +# {{{ array manipulations def test_actx_stack(actx_factory): actx = actx_factory() @@ -589,6 +589,39 @@ def test_container_norm(actx_factory, ord): # }}} +# {{{ test from_numpy and to_numpy + +def test_numpy_conversion(actx_factory): + actx = actx_factory() + + ac = MyContainer( + name="test_numpy_conversion", + mass=np.random.rand(42), + momentum=make_obj_array([np.random.rand(42) for _ in range(3)]), + enthalpy=np.random.rand(42), + ) + + from arraycontext import from_numpy, to_numpy + ac_actx = from_numpy(ac, actx) + ac_roundtrip = to_numpy(ac_actx, actx) + + assert np.allclose(ac.mass, ac_roundtrip.mass) + assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0]) + + from dataclasses import replace + ac_with_cl = replace(ac, enthalpy=ac_actx.mass) + with pytest.raises(TypeError): + from_numpy(ac_with_cl, actx) + + with pytest.raises(TypeError): + from_numpy(ac_actx, actx) + + with pytest.raises(ValueError): + to_numpy(ac, actx) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab