diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 61203029ca484cba9918afc54a4c547b9ecf8905..c0bca7703c04cf6391375fbfdbbc7f8bed8d4693 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -58,6 +58,7 @@ from .container.traversal import ( rec_map_reduce_array_container, rec_multimap_reduce_array_container, thaw, freeze, + flatten, unflatten, from_numpy, to_numpy) from .impl.pyopencl import PyOpenCLArrayContext @@ -92,6 +93,7 @@ __all__ = ( "map_reduce_array_container", "multimap_reduce_array_container", "rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "thaw", "freeze", + "flatten", "unflatten", "from_numpy", "to_numpy", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 8bd65d8d4bc900824a540251526ce27710a4c67c..d2f9761434d6a17fceb4aff2a4b9508b85c16f40 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -120,7 +120,11 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: r"""Serialize the array container into an iterable over its components. The order of the components and their identifiers are entirely under - the control of the container class. + the control of the container class. However, the order is required to be + deterministic, i.e. two calls to :func:`serialize_container` on + array containers of the same types with the same number of + sub-arrays must result in an iterable with the keys in the same + order. If *ary* is mutable, the serialization function is not required to ensure that the serialization result reflects the array state at the time of the diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 8858e86d1eaafd770674c3ad2b32cbf5eb90850e..8d0f9f33bb5bfa67bb6a595de9b214a6481bfbf1 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -23,6 +23,11 @@ Freezing and thawing .. autofunction:: freeze .. autofunction:: thaw +Flattening and unflattening +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: flatten +.. autofunction:: unflatten + Numpy conversion ~~~~~~~~~~~~~~~~ .. autofunction:: from_numpy @@ -493,6 +498,131 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: # }}} +# {{{ flatten / unflatten + +def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: + """Convert all arrays in the :class:`~arraycontext.ArrayContainer` + into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`. + + The operation requires :attr:`arraycontext.ArrayContext.np` to have + ``ravel`` and ``concatenate`` methods implemented. The order in which the + individual leaf arrays appear in the final array is dependent on the order + given by :func:`~arraycontext.serialize_container`. + """ + common_dtype = None + result: List[Any] = [] + + def _flatten(subary: ArrayOrContainerT) -> None: + nonlocal common_dtype + + try: + iterable = serialize_container(subary) + except TypeError: + if common_dtype is None: + common_dtype = subary.dtype + + if subary.dtype != common_dtype: + raise ValueError("arrays in container have different dtypes: " + f"got {subary.dtype}, expected {common_dtype}") + + try: + flat_subary = actx.np.ravel(subary, order="C") + except ValueError as exc: + # NOTE: we can't do much if the array context fails to ravel, + # since it is the one responsible for the actual memory layout + if hasattr(subary, "strides"): + strides_msg = f" and strides {subary.strides}" + else: + strides_msg = "" + + raise NotImplementedError( + f"'{type(actx).__name__}.np.ravel' failed to reshape " + f"an array with shape {subary.shape}{strides_msg}. " + "This functionality needs to be implemented by the " + "array context.") from exc + + result.append(flat_subary) + else: + for _, isubary in iterable: + _flatten(isubary) + + _flatten(ary) + + return actx.np.concatenate(result) + + +def unflatten( + template: ArrayOrContainerT, ary: Any, + actx: ArrayContext) -> ArrayOrContainerT: + """Unflatten an array *ary* produced by :func:`flatten` back into an + :class:`~arraycontext.ArrayContainer`. + + The order and sizes of each slice into *ary* are determined by the + array container *template*. + """ + # NOTE: https://github.com/python/mypy/issues/7057 + offset = 0 + + def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT: + nonlocal offset + + try: + iterable = serialize_container(template_subary) + except TypeError: + if (offset + template_subary.size) > ary.size: + raise ValueError("'template' and 'ary' sizes do not match: " + "'template' is too large") + + if template_subary.dtype != ary.dtype: + raise ValueError("'template' dtype does not match 'ary': " + f"got {template_subary.dtype}, expected {ary.dtype}") + + flat_subary = ary[offset:offset + template_subary.size] + try: + subary = actx.np.reshape(flat_subary, + template_subary.shape, order="C") + except ValueError as exc: + # NOTE: we can't do much if the array context fails to reshape, + # since it is the one responsible for the actual memory layout + raise NotImplementedError( + f"'{type(actx).__name__}.np.reshape' failed to reshape " + f"the flat array into shape {template_subary.shape}. " + "This functionality needs to be implemented by the " + "array context.") from exc + + if hasattr(template_subary, "strides"): + if template_subary.strides != subary.strides: + raise ValueError( + f"strides do not match template: got {subary.strides}, " + f"expected {template_subary.strides}") + + offset += template_subary.size + return subary + else: + return deserialize_container(template_subary, [ + (key, _unflatten(isubary)) for key, isubary in iterable + ]) + + if not isinstance(ary, actx.array_types): + raise TypeError("'ary' does not have a type supported by the provided " + f"array context: got '{type(ary).__name__}', expected one of " + f"{actx.array_types}") + + if ary.ndim != 1: + raise ValueError( + "only one dimensional arrays can be unflattened: " + f"'ary' has shape {ary.shape}") + + result = _unflatten(template) + if offset != ary.size: + raise ValueError("'template' and 'ary' sizes do not match: " + "'ary' is too large") + + return result + +# }}} + + # {{{ numpy conversion def from_numpy(ary: Any, actx: ArrayContext) -> Any: diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 882a2e5c51a74d0149267f4353d651f5ccc5b127..c8c1715ae77674c56bab53efc1e5b338cbf0a909 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -172,8 +172,10 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): queue=self._array_context.queue), *arrays) - def reshape(self, a, newshape): - return cl_array.reshape(a, newshape) + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) def concatenate(self, arrays, axis=0): return cl_array.concatenate( diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 9721826449b240013f51cbb44eb52980aaa2a167..dacf7271da70b0ea542744fb1f89b2e9a2d533b2 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -64,8 +64,10 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): return super().__getattr__(name) - def reshape(self, a, newshape): - return rec_multimap_array_container(pt.reshape, a, newshape) + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: pt.reshape(a, newshape, order=order), + a) def transpose(self, a, axes=None): return rec_multimap_array_container(pt.transpose, a, axes) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0ad4aade35b4ba5ee156be4bbfa871cc271947c4..0dcb1a6611f9240f4106119f07bda0f77c551398 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -159,7 +159,7 @@ class DOFArray: @serialize_container.register(DOFArray) def _serialize_dof_container(ary: DOFArray): - return enumerate(ary.data) + return list(enumerate(ary.data)) @deserialize_container.register(DOFArray) @@ -203,17 +203,27 @@ def randn(shape, dtype): rng = np.random.default_rng() dtype = np.dtype(dtype) + if shape == 0: + ashape = 1 + else: + ashape = shape + if dtype.kind == "c": dtype = np.dtype(f"<f{dtype.itemsize // 2}") - return rng.standard_normal(shape, dtype) \ - + 1j * rng.standard_normal(shape, dtype) + r = rng.standard_normal(ashape, dtype) \ + + 1j * rng.standard_normal(ashape, dtype) elif dtype.kind == "f": - return rng.standard_normal(shape, dtype) + r = rng.standard_normal(ashape, dtype) elif dtype.kind == "i": - return rng.integers(0, 128, shape, dtype) + r = rng.integers(0, 512, ashape, dtype) else: raise TypeError(dtype.kind) + if shape == 0: + return np.array(r[0]) + + return r + def assert_close_to_numpy(actx, op, args): assert np.allclose( @@ -672,11 +682,14 @@ class MyContainerDOFBcast: return self.mass.array_context -def _get_test_containers(actx, ambient_dim=2, size=50_000): - if size == 0: - x = DOFArray(actx, (actx.from_numpy(np.array(np.random.randn())),)) - else: - x = DOFArray(actx, (actx.from_numpy(np.random.randn(size)),)) +def _get_test_containers(actx, ambient_dim=2, shapes=50_000): + from numbers import Number + if isinstance(shapes, (Number, tuple)): + shapes = [shapes] + + x = DOFArray(actx, tuple([ + actx.from_numpy(randn(shape, np.float64)) + for shape in shapes])) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter dataclass_of_dofs = MyContainer( @@ -705,7 +718,7 @@ def _get_test_containers(actx, ambient_dim=2, size=50_000): def test_container_scalar_map(actx_factory): actx = actx_factory() - arys = _get_test_containers(actx, size=0) + arys = _get_test_containers(actx, shapes=0) arys += (np.pi,) from arraycontext import ( @@ -877,16 +890,76 @@ def test_container_norm(actx_factory, ord): # }}} +# {{{ test flatten and unflatten + +@pytest.mark.parametrize("shapes", [ + 0, # tests device scalars when flattening + 512, + [(128, 67)], + [(127, 67), (18, 0)], # tests 0-sized arrays + [(64, 7), (154, 12)] + ]) +def test_flatten_array_container(actx_factory, shapes): + if np.prod(shapes) == 0: + # https://github.com/inducer/loopy/pull/497 + # NOTE: only fails for the pytato array context at the moment + pytest.xfail("strides do not match in subary") + + actx = actx_factory() + + from arraycontext import flatten, unflatten + arys = _get_test_containers(actx, shapes=shapes) + + for ary in arys: + flat = flatten(ary, actx) + assert flat.ndim == 1 + + ary_roundtrip = unflatten(ary, flat, actx) + + from arraycontext import rec_multimap_reduce_array_container + assert rec_multimap_reduce_array_container( + np.prod, + lambda x, y: x.shape == y.shape, + ary, ary_roundtrip) + + assert actx.to_numpy( + actx.np.linalg.norm(ary - ary_roundtrip) + ) < 1.0e-15 + + +def test_flatten_array_container_failure(actx_factory): + actx = actx_factory() + + from arraycontext import flatten, unflatten + ary = _get_test_containers(actx, shapes=512)[0] + flat_ary = flatten(ary, actx) + + with pytest.raises(TypeError): + # cannot unflatten from a numpy array + unflatten(ary, actx.to_numpy(flat_ary), actx) + + with pytest.raises(ValueError): + # cannot unflatten non-flat arrays + unflatten(ary, flat_ary.reshape(2, -1), actx) + + with pytest.raises(ValueError): + # cannot unflatten partially + unflatten(ary, flat_ary[:-1], actx) + +# }}} + + # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory): actx = actx_factory() + nelements = 42 ac = MyContainer( name="test_numpy_conversion", - mass=np.random.rand(42), - momentum=make_obj_array([np.random.rand(42) for _ in range(3)]), - enthalpy=np.random.rand(42), + mass=np.random.rand(nelements, nelements), + momentum=make_obj_array([np.random.rand(nelements) for _ in range(3)]), + enthalpy=np.array(np.random.rand()), ) from arraycontext import from_numpy, to_numpy