From ac8d9d2691987aa65ab3f477e82d24409f88f885 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 22 Sep 2021 16:43:42 -0500
Subject: [PATCH] raise TypeError instead of NotImplementedError in
 de/serialize_container

---
 arraycontext/container/__init__.py  |  7 ++++---
 arraycontext/container/traversal.py | 12 ++++++------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index e92b527..9eb3c45 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -135,7 +135,7 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]:
         for arbitrarily nested structures. The identifiers need to be hashable
         but are otherwise treated as opaque.
     """
-    raise NotImplementedError(type(ary).__name__)
+    raise TypeError(f"'{type(ary).__name__}' cannot be serialized as a container")
 
 
 @singledispatch
@@ -148,7 +148,8 @@ def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) ->
     :param iterable: an iterable that mirrors the output of
         :meth:`serialize_container`.
     """
-    raise NotImplementedError(type(template).__name__)
+    raise TypeError(
+            f"'{type(template).__name__}' cannot be deserialized as a container")
 
 
 def is_array_container_type(cls: type) -> bool:
@@ -190,7 +191,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
 def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
     if ary.dtype.char != "O":
         raise ValueError(
-                f"only object arrays are supported, given dtype '{ary.dtype}'")
+                f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'")
 
     # special-cased for speed
     if ary.ndim == 1:
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index aa91d34..ea5fce9 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -180,7 +180,7 @@ def map_array_container(
     """
     try:
         iterable = serialize_container(ary)
-    except NotImplementedError:
+    except TypeError:
         return f(ary)
     else:
         return deserialize_container(ary, [
@@ -265,7 +265,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any],
     """
     try:
         iterable = serialize_container(ary)
-    except NotImplementedError:
+    except TypeError:
         raise ValueError(
                 f"Non-array container type has no key: {type(ary).__name__}")
     else:
@@ -287,7 +287,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
             _ary: ArrayOrContainerT) -> ArrayOrContainerT:
         try:
             iterable = serialize_container(_ary)
-        except NotImplementedError:
+        except TypeError:
             return f(keys, _ary)
         else:
             return deserialize_container(_ary, [
@@ -316,7 +316,7 @@ def map_reduce_array_container(
     """
     try:
         iterable = serialize_container(ary)
-    except NotImplementedError:
+    except TypeError:
         return map_func(ary)
     else:
         return reduce_func([
@@ -391,7 +391,7 @@ def rec_map_reduce_array_container(
     def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
         try:
             iterable = serialize_container(_ary)
-        except NotImplementedError:
+        except TypeError:
             return map_func(_ary)
         else:
             return reduce_func([
@@ -483,7 +483,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
     """
     try:
         iterable = serialize_container(ary)
-    except NotImplementedError:
+    except TypeError:
         return actx.thaw(ary)
     else:
         return deserialize_container(ary, [
-- 
GitLab