From ac8d9d2691987aa65ab3f477e82d24409f88f885 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Wed, 22 Sep 2021 16:43:42 -0500 Subject: [PATCH] raise TypeError instead of NotImplementedError in de/serialize_container --- arraycontext/container/__init__.py | 7 ++++--- arraycontext/container/traversal.py | 12 ++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index e92b527..9eb3c45 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -135,7 +135,7 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: for arbitrarily nested structures. The identifiers need to be hashable but are otherwise treated as opaque. """ - raise NotImplementedError(type(ary).__name__) + raise TypeError(f"'{type(ary).__name__}' cannot be serialized as a container") @singledispatch @@ -148,7 +148,8 @@ def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> :param iterable: an iterable that mirrors the output of :meth:`serialize_container`. """ - raise NotImplementedError(type(template).__name__) + raise TypeError( + f"'{type(template).__name__}' cannot be deserialized as a container") def is_array_container_type(cls: type) -> bool: @@ -190,7 +191,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: if ary.dtype.char != "O": raise ValueError( - f"only object arrays are supported, given dtype '{ary.dtype}'") + f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'") # special-cased for speed if ary.ndim == 1: diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index aa91d34..ea5fce9 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -180,7 +180,7 @@ def map_array_container( """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: return f(ary) else: return deserialize_container(ary, [ @@ -265,7 +265,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: raise ValueError( f"Non-array container type has no key: {type(ary).__name__}") else: @@ -287,7 +287,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except NotImplementedError: + except TypeError: return f(keys, _ary) else: return deserialize_container(_ary, [ @@ -316,7 +316,7 @@ def map_reduce_array_container( """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: return map_func(ary) else: return reduce_func([ @@ -391,7 +391,7 @@ def rec_map_reduce_array_container( def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except NotImplementedError: + except TypeError: return map_func(_ary) else: return reduce_func([ @@ -483,7 +483,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: """ try: iterable = serialize_container(ary) - except NotImplementedError: + except TypeError: return actx.thaw(ary) else: return deserialize_container(ary, [ -- GitLab