From 547899707670b7bbf516bf6f923661a8e0dd538f Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Mon, 27 Sep 2021 10:01:46 -0500 Subject: [PATCH] raise if flattened container does not have homogeneous dtypes --- arraycontext/container/traversal.py | 45 +++++++++++++++++++---------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 06a0726..08d9ba9 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -510,26 +510,36 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: individual leaf arrays appear in the final array is dependent on the order given by :func:`~arraycontext.serialize_container`. """ - @memoize_in(actx, (flatten, "ravel_prg")) - def _ravel_prg(shape: Tuple[int, ...]) -> Any: - raise NotImplementedError + common_dtype = None + result: List[Any] = [] def _flatten(subary: ArrayOrContainerT) -> None: + nonlocal common_dtype + try: iterable = serialize_container(subary) except TypeError: + if common_dtype is None: + common_dtype = subary.dtype + + if subary.dtype != common_dtype: + raise ValueError("arrays in container have different dtypes: " + f"got {subary.dtype}, expected {common_dtype}") + try: flat_subary = actx.np.ravel(subary, order="A") - except ValueError: - flat_subary = actx.call_loopy( - _ravel_prg(subary.shape), ary=subary) + except ValueError as exc: + # NOTE: we can't do much if the array context fails to ravel, + # since it is the one responsible for the actual memory layout + raise NotImplementedError("'flatten' requires advanced reshaping " + "functionality that is not implemented in the array " + f"context '{type(actx).__name__}'") from exc result.append(flat_subary) else: for _, isubary in iterable: _flatten(isubary) - result: List[Any] = [] _flatten(ary) return actx.np.concatenate(result) @@ -547,10 +557,6 @@ def unflatten( # NOTE: https://github.com/python/mypy/issues/7057 offset = 0 - @memoize_in(actx, (unflatten, "reshape_prg")) - def _reshape_prg(shape: Tuple[int, ...]) -> Any: - raise NotImplementedError - def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT: nonlocal offset @@ -562,14 +568,21 @@ def unflatten( if offset > ary.size: raise ValueError("'template' and 'ary' sizes do not match") + if template_subary.dtype != ary.dtype: + raise ValueError("'template' dtype does not match 'ary': " + f"got {template_subary.dtype}, expected {ary.dtype}") + flat_subary = ary[offset - template_subary.size:offset] try: subary = actx.np.reshape(flat_subary, template_subary.shape) - except ValueError: - subary = actx.call_loopy( - _reshape_prg(template_subary.shape), ary=flat_subary) - - return actx.np.astype(subary, template_subary.dtype) + except ValueError as exc: + # NOTE: we can't do much if the array context fails to reshape, + # since it is the one responsible for the actual memory layout + raise NotImplementedError("'unflatten' requires advanced reshaping " + "functionality that is not implemented in the array " + f"context '{type(actx).__name__}'") from exc + + return subary else: return deserialize_container(template_subary, [ (key, _unflatten(isubary)) for key, isubary in iterable -- GitLab