diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index f1f9762c42de781508f0dc161be1e6788fbe2bce..6610494f8b97a3dba8a8976735ce872723bfcab7 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -58,6 +58,7 @@ from .container.traversal import ( rec_map_reduce_array_container, rec_multimap_reduce_array_container, thaw, freeze, + flatten, unflatten, from_numpy, to_numpy, flatten_to_numpy, unflatten_from_numpy) @@ -93,6 +94,7 @@ __all__ = ( "map_reduce_array_container", "multimap_reduce_array_container", "rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "thaw", "freeze", + "flatten", "unflatten", "from_numpy", "to_numpy", "flatten_to_numpy", "unflatten_from_numpy", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index acf7ff2051561317676f7606bbb8bfee58b210a2..c3bf32fe1ec6bc941cf52f26030d36aef952eaf4 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -23,6 +23,11 @@ Freezing and thawing .. autofunction:: freeze .. autofunction:: thaw +Flattening and unflattening +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: flatten +.. autofunction:: unflatten + Numpy conversion ~~~~~~~~~~~~~~~~ .. autofunction:: from_numpy @@ -64,6 +69,7 @@ from arraycontext.context import ArrayContext from arraycontext.container import ( ContainerT, ArrayOrContainerT, is_array_container, serialize_container, deserialize_container) +from pytools import memoize_in # {{{ array container traversal helpers @@ -495,6 +501,97 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: # }}} +# {{{ flatten / unflatten + +def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: + """Convert all arrays in the :class:`~arraycontext.ArrayContainer` + into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`. + + The operation requires :attr:`arraycontext.ArrayContext.np` to have + ``ravel`` and ``concatenate`` methods implemented. The order in which the + individual leaf arrays appear in the final array is dependent on the order + given by :func:`~arraycontext.serialize_container`. + """ + @memoize_in(actx, (flatten, "ravel_prg")) + def _ravel_prg(shape: Tuple[int, ...]) -> Any: + raise NotImplementedError + + def _flatten(subary: ArrayOrContainerT) -> None: + try: + iterable = serialize_container(subary) + except TypeError: + try: + flat_subary = actx.np.ravel(subary, order="A") + except ValueError: + flat_subary = actx.call_loopy( + _ravel_prg(subary.shape), ary=subary) + + result.append(flat_subary) + else: + for _, isubary in iterable: + _flatten(isubary) + + result: List[Any] = [] + _flatten(ary) + + return actx.np.concatenate(result) + + +def unflatten( + template: ArrayOrContainerT, ary: Any, + actx: ArrayContext) -> ArrayOrContainerT: + """Unflatten an array produced by :func:`flatten` back into an + :class:`~arraycontext.ArrayContainer`. + + The order and sizes of each slice into *ary* are determined by the + array container *template*. + """ + # NOTE: https://github.com/python/mypy/issues/7057 + offset = 0 + + @memoize_in(actx, (unflatten, "reshape_prg")) + def _reshape_prg(shape: Tuple[int, ...]) -> Any: + raise NotImplementedError + + def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT: + nonlocal offset + + try: + iterable = serialize_container(template_subary) + except TypeError: + # NOTE: the max is needed to handle device scalars with size == 0 + offset += max(1, template_subary.size) + if offset > ary.size: + raise ValueError("'template' and 'ary' sizes do not match") + + flat_subary = ary[offset - template_subary.size:offset] + try: + subary = actx.np.reshape(flat_subary, template_subary.shape) + except ValueError: + subary = actx.call_loopy( + _reshape_prg(template_subary.shape), ary=flat_subary) + + return actx.np.astype(subary, template_subary.dtype) + else: + return deserialize_container(template_subary, [ + (key, _unflatten(isubary)) for key, isubary in iterable + ]) + + if not isinstance(ary, actx.array_types): + raise TypeError("'ary' does not have a type supported by the provided " + f"array context: got '{type(ary).__name__}', expected one of " + f"{actx.array_types}") + + if ary.ndim != 1: + raise ValueError( + "only one dimensional arrays can be unflattened: " + f"'ary' has shape {ary.shape}") + + return _unflatten(template) + +# }}} + + # {{{ numpy conversion def from_numpy(ary: Any, actx: ArrayContext) -> Any: diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 0ba81e12958991d650032dc279f30e8286293fba..c5b57ef58b49ed2bdb95dd3547839a5222ee6eeb 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -265,6 +265,11 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): result = result.get()[()] return result + def astype(self, a, dtype): + return rec_map_array_container( + lambda x: x.astype(dtype, queue=self._array_context.queue), + a) + # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index dfa238d03f36da80f895911f5a7f028499a7e661..5f9cfc7848c5bf218571210f5787861ff267e944 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -159,7 +159,7 @@ class DOFArray: @serialize_container.register(DOFArray) def _serialize_dof_container(ary: DOFArray): - return enumerate(ary.data) + return list(enumerate(ary.data)) @deserialize_container.register(DOFArray) @@ -877,6 +877,27 @@ def test_container_norm(actx_factory, ord): # }}} +# {{{ test flatten and unflatten + +def test_flatten_array_container(actx_factory): + actx = actx_factory() + if not hasattr(actx.np, "astype"): + pytest.skip(f"'astype' not implemented on '{type(actx).__name__}'") + + from arraycontext import flatten, unflatten + arys = _get_test_containers(actx, size=512) + for ary in arys: + flat = flatten(ary, actx) + assert flat.ndim == 1 + + ary_roundtrip = unflatten(ary, flat, actx) + assert actx.to_numpy( + actx.np.linalg.norm(ary - ary_roundtrip) + ) < 1.0e-15 + +# }}} + + # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory):