From f0e3ac009cab332aa6021bc8633e23922d517cf8 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Wed, 20 Oct 2021 19:51:07 -0500 Subject: [PATCH] add specific array container exception type --- arraycontext/__init__.py | 4 +-- arraycontext/container/__init__.py | 19 +++++++++---- arraycontext/container/traversal.py | 36 +++++++++++------------- arraycontext/fake_numpy.py | 4 +-- arraycontext/impl/pyopencl/fake_numpy.py | 4 +-- arraycontext/impl/pytato/fake_numpy.py | 4 +-- 6 files changed, 37 insertions(+), 34 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 0147e62..2297765 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -38,7 +38,7 @@ from .transform_metadata import (CommonSubexpressionTag, from .metadata import _FirstAxisIsElementsTag from .container import ( - ArrayContainer, + ArrayContainer, NotAnArrayContainerError, is_array_container, is_array_container_type, get_container_context, get_container_context_recursively, serialize_container, deserialize_container, @@ -79,7 +79,7 @@ __all__ = ( "CommonSubexpressionTag", "ElementwiseMapKernelTag", - "ArrayContainer", + "ArrayContainer", "NotAnArrayContainerError", "is_array_container", "is_array_container_type", "get_container_context", "get_container_context_recursively", "serialize_container", "deserialize_container", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index bae0bcd..72bd024 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -20,6 +20,8 @@ .. autoclass:: ArrayContainer +.. autoexception:: NotAnArrayContainerError + Serialization/deserialization ----------------------------- .. autofunction:: is_array_container_type @@ -115,6 +117,10 @@ class ArrayContainer: """ +class NotAnArrayContainerError(TypeError): + """:class:`TypeError` subclass raised when an array container is expected.""" + + @singledispatch def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: r"""Serialize the array container into an iterable over its components. @@ -137,7 +143,8 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: for arbitrarily nested structures. The identifiers need to be hashable but are otherwise treated as opaque. """ - raise TypeError(f"'{type(ary).__name__}' cannot be serialized as a container") + raise NotAnArrayContainerError( + f"'{type(ary).__name__}' cannot be serialized as a container") @singledispatch @@ -150,7 +157,7 @@ def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> :param iterable: an iterable that mirrors the output of :meth:`serialize_container`. """ - raise TypeError( + raise NotAnArrayContainerError( f"'{type(template).__name__}' cannot be deserialized as a container") @@ -181,8 +188,8 @@ def is_array_container(ary: Any) -> bool: from warnings import warn warn("is_array_container is deprecated and will be removed in 2022. " "If you must know precisely whether something is an array container, " - "try serializing it and catch TypeError. For a cheaper option, see " - "is_array_container_type.", + "try serializing it and catch NotAnArrayContainerError. For a " + "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) is not serialize_container.__wrapped__) # type:ignore[attr-defined] @@ -206,7 +213,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: @serialize_container.register(np.ndarray) def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: if ary.dtype.char != "O": - raise TypeError( + raise NotAnArrayContainerError( f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'") # special-cased for speed @@ -254,7 +261,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return actx else: for _, subary in iterable: diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index badb7d6..c339609 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -69,7 +69,7 @@ import numpy as np from arraycontext.context import ArrayContext from arraycontext.container import ( - ContainerT, ArrayOrContainerT, + ContainerT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container) @@ -93,7 +93,7 @@ def _map_array_container_impl( try: iterable = serialize_container(_ary) - except TypeError: + except NotAnArrayContainerError: return f(_ary) else: return deserialize_container(_ary, [ @@ -127,7 +127,7 @@ def _multimap_array_container_impl( try: iterable_template = serialize_container(template_ary) - except TypeError: + except NotAnArrayContainerError: return f(*_args) else: pass @@ -170,7 +170,7 @@ def _multimap_array_container_impl( # FIXME: this will serialize again once `rec` is called, which is # not great, but it doesn't seem like there's a good way to avoid it _ = serialize_container(arg) - except TypeError: + except NotAnArrayContainerError: pass else: container_indices.append(i) @@ -231,7 +231,7 @@ def map_array_container( """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return f(ary) else: return deserialize_container(ary, [ @@ -316,7 +316,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: raise ValueError( f"Non-array container type has no key: {type(ary).__name__}") else: @@ -338,7 +338,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except TypeError: + except NotAnArrayContainerError: return f(keys, _ary) else: return deserialize_container(_ary, [ @@ -367,7 +367,7 @@ def map_reduce_array_container( """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return map_func(ary) else: return reduce_func([ @@ -442,7 +442,7 @@ def rec_map_reduce_array_container( def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except TypeError: + except NotAnArrayContainerError: return map_func(_ary) else: return reduce_func([ @@ -501,7 +501,7 @@ def freeze( """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: if actx is None: raise TypeError( f"cannot freeze arrays of type {type(ary).__name__} " @@ -538,7 +538,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return actx.thaw(ary) else: return deserialize_container(ary, [ @@ -567,7 +567,7 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: try: iterable = serialize_container(subary) - except TypeError: + except NotAnArrayContainerError: if common_dtype is None: common_dtype = subary.dtype @@ -618,7 +618,7 @@ def unflatten( try: iterable = serialize_container(template_subary) - except TypeError: + except NotAnArrayContainerError: if (offset + template_subary.size) > ary.size: raise ValueError("'template' and 'ary' sizes do not match: " "'template' is too large") @@ -682,9 +682,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any: The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ def _from_numpy_with_check(subary: Any) -> Any: - if np.isscalar(subary): - return subary - elif isinstance(subary, np.ndarray): + if isinstance(subary, np.ndarray) or np.isscalar(subary): return actx.from_numpy(subary) else: raise TypeError(f"array is not an ndarray: '{type(subary).__name__}'") @@ -699,9 +697,7 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any: The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`. """ def _to_numpy_with_check(subary: Any) -> Any: - if np.isscalar(subary): - return subary - elif isinstance(subary, actx.array_types): + if isinstance(subary, actx.array_types) or np.isscalar(subary): return actx.to_numpy(subary) else: raise TypeError( @@ -734,7 +730,7 @@ def outer(a: Any, b: Any) -> Any: def treat_as_scalar(x: Any) -> bool: try: serialize_container(x) - except TypeError: + except NotAnArrayContainerError: return True else: return ( diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index fec1463..022e053 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -24,7 +24,7 @@ THE SOFTWARE. import numpy as np -from arraycontext.container import serialize_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, multimapped_over_array_containers) from pytools import memoize_in @@ -258,7 +258,7 @@ class BaseFakeNumpyLinalgNamespace: try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: pass else: return _reduce_norm(actx, [ diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 83c1b43..4e9d48b 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -33,7 +33,7 @@ import numpy as np from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace -from arraycontext.container import serialize_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, rec_multimap_array_container, @@ -252,7 +252,7 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): try: iterable = zip(serialize_container(x), serialize_container(y)) - except TypeError: + except NotAnArrayContainerError: if x.shape != y.shape: return false else: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 389ea07..905fd0f 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -28,7 +28,7 @@ import numpy as np from arraycontext.fake_numpy import ( BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, ) -from arraycontext.container import serialize_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, rec_multimap_array_container, @@ -186,7 +186,7 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): try: iterable = zip(serialize_container(x), serialize_container(y)) - except TypeError: + except NotAnArrayContainerError: if x.shape != y.shape: return false else: -- GitLab