diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index e92b527a789c6275f4e4c918ef9299cf7e279064..9eb3c45cc9c95c479a05422c9646e0489b1f2dc5 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -135,7 +135,7 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: for arbitrarily nested structures. The identifiers need to be hashable but are otherwise treated as opaque. """ - raise NotImplementedError(type(ary).__name__) + raise TypeError(f"'{type(ary).__name__}' cannot be serialized as a container") @singledispatch @@ -148,7 +148,8 @@ def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> :param iterable: an iterable that mirrors the output of :meth:`serialize_container`. """ - raise NotImplementedError(type(template).__name__) + raise TypeError( + f"'{type(template).__name__}' cannot be deserialized as a container") def is_array_container_type(cls: type) -> bool: @@ -190,7 +191,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: if ary.dtype.char != "O": raise ValueError( - f"only object arrays are supported, given dtype '{ary.dtype}'") + f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'") # special-cased for speed if ary.ndim == 1: diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index aa91d342ff95adc7bad80142859bc33c8d5234ab..ea5fce9cc34d8530d06e34deef6cb2af55626b65 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -180,7 +180,7 @@ def map_array_container( """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: return f(ary) else: return deserialize_container(ary, [ @@ -265,7 +265,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: raise ValueError( f"Non-array container type has no key: {type(ary).__name__}") else: @@ -287,7 +287,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except NotImplementedError: + except TypeError: return f(keys, _ary) else: return deserialize_container(_ary, [ @@ -316,7 +316,7 @@ def map_reduce_array_container( """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: return map_func(ary) else: return reduce_func([ @@ -391,7 +391,7 @@ def rec_map_reduce_array_container( def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except NotImplementedError: + except TypeError: return map_func(_ary) else: return reduce_func([ @@ -483,7 +483,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: return actx.thaw(ary) else: return deserialize_container(ary, [