diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 83f1684df7162f3ab60aeb790479df9807f6cb2d..34df04b60aaf57fc98ce8236b76cd0842314d130 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -47,7 +47,8 @@ from .container.traversal import ( rec_multimap_array_container, mapped_over_array_containers, multimapped_over_array_containers, - thaw, freeze) + thaw, freeze, + from_numpy, to_numpy) from .impl.pyopencl import PyOpenCLArrayContext @@ -74,6 +75,7 @@ __all__ = ( "mapped_over_array_containers", "multimapped_over_array_containers", "thaw", "freeze", + "from_numpy", "to_numpy", "PyOpenCLArrayContext", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index ab8a4ff8055e7893bcd390f0e1a75ba34156bf23..fce37a6b5c2bfa53d30bc018da1d674d591af6b5 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -15,6 +15,11 @@ Freezing and thawing ~~~~~~~~~~~~~~~~~~~~ .. autofunction:: freeze .. autofunction:: thaw + +Numpy conversion +~~~~~~~~~~~~~~~~ +.. autofunction:: from_numpy +.. autofunction:: to_numpy """ __copyright__ = """ @@ -43,6 +48,9 @@ THE SOFTWARE. from typing import Any, Callable from functools import update_wrapper, partial, singledispatch + +import numpy as np + from arraycontext.container import (is_array_container, serialize_container, deserialize_container) @@ -265,4 +273,34 @@ def thaw(ary, actx): # }}} + +# {{{ numpy conversion + +def from_numpy(ary, actx): + """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` + to the base array type of :class:`~arraycontext.ArrayContext`. + + The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. + """ + def _from_numpy(subary): + if isinstance(subary, np.ndarray) and subary.dtype != "O": + return actx.from_numpy(subary) + elif is_array_container(subary): + return map_array_container(_from_numpy, subary) + else: + raise TypeError(f"unrecognized array type: '{type(subary).__name__}'") + + return _from_numpy(ary) + + +def to_numpy(ary, actx): + """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to + :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. + + The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`. + """ + return rec_map_array_container(actx.to_numpy, ary) + +# }}} + # vim: foldmethod=marker diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index a8bdcb79377e9b262db0f5ef2597ad94a6a03aa5..f1f78a16b1b7e35bb2570f86cced027ab7b5d286 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -204,7 +204,7 @@ def test_array_context_np_like(actx_factory, sym_name, n_args): # }}} -# {{{ Array manipulations +# {{{ array manipulations def test_actx_stack(actx_factory): actx = actx_factory() @@ -589,6 +589,39 @@ def test_container_norm(actx_factory, ord): # }}} +# {{{ test from_numpy and to_numpy + +def test_numpy_conversion(actx_factory): + actx = actx_factory() + + ac = MyContainer( + name="test_numpy_conversion", + mass=np.random.rand(42), + momentum=make_obj_array([np.random.rand(42) for _ in range(3)]), + enthalpy=np.random.rand(42), + ) + + from arraycontext import from_numpy, to_numpy + ac_actx = from_numpy(ac, actx) + ac_roundtrip = to_numpy(ac_actx, actx) + + assert np.allclose(ac.mass, ac_roundtrip.mass) + assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0]) + + from dataclasses import replace + ac_with_cl = replace(ac, enthalpy=ac_actx.mass) + with pytest.raises(TypeError): + from_numpy(ac_with_cl, actx) + + with pytest.raises(TypeError): + from_numpy(ac_actx, actx) + + with pytest.raises(ValueError): + to_numpy(ac, actx) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: