diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 0147e62910f8681c1af9cd8d1aa125d79597b09b..22977658c47c4a56c12df59f020567e1ee71d3ff 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -38,7 +38,7 @@ from .transform_metadata import (CommonSubexpressionTag, from .metadata import _FirstAxisIsElementsTag from .container import ( - ArrayContainer, + ArrayContainer, NotAnArrayContainerError, is_array_container, is_array_container_type, get_container_context, get_container_context_recursively, serialize_container, deserialize_container, @@ -79,7 +79,7 @@ __all__ = ( "CommonSubexpressionTag", "ElementwiseMapKernelTag", - "ArrayContainer", + "ArrayContainer", "NotAnArrayContainerError", "is_array_container", "is_array_container_type", "get_container_context", "get_container_context_recursively", "serialize_container", "deserialize_container", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index bae0bcdfefc20244c49f224e8756dd35ce8eafe9..72bd024399ffde45f54b9efa9251ec66afcd2c1a 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -20,6 +20,8 @@ .. autoclass:: ArrayContainer +.. autoexception:: NotAnArrayContainerError + Serialization/deserialization ----------------------------- .. autofunction:: is_array_container_type @@ -115,6 +117,10 @@ class ArrayContainer: """ +class NotAnArrayContainerError(TypeError): + """:class:`TypeError` subclass raised when an array container is expected.""" + + @singledispatch def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: r"""Serialize the array container into an iterable over its components. @@ -137,7 +143,8 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: for arbitrarily nested structures. The identifiers need to be hashable but are otherwise treated as opaque. """ - raise TypeError(f"'{type(ary).__name__}' cannot be serialized as a container") + raise NotAnArrayContainerError( + f"'{type(ary).__name__}' cannot be serialized as a container") @singledispatch @@ -150,7 +157,7 @@ def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> :param iterable: an iterable that mirrors the output of :meth:`serialize_container`. """ - raise TypeError( + raise NotAnArrayContainerError( f"'{type(template).__name__}' cannot be deserialized as a container") @@ -181,8 +188,8 @@ def is_array_container(ary: Any) -> bool: from warnings import warn warn("is_array_container is deprecated and will be removed in 2022. " "If you must know precisely whether something is an array container, " - "try serializing it and catch TypeError. For a cheaper option, see " - "is_array_container_type.", + "try serializing it and catch NotAnArrayContainerError. For a " + "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) is not serialize_container.__wrapped__) # type:ignore[attr-defined] @@ -206,7 +213,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: @serialize_container.register(np.ndarray) def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: if ary.dtype.char != "O": - raise TypeError( + raise NotAnArrayContainerError( f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'") # special-cased for speed @@ -254,7 +261,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return actx else: for _, subary in iterable: diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index badb7d6052026dcacc2a6ec2e2f267bf5f58ef2c..c339609c98cde40599bb90555339f3937cc45a65 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -69,7 +69,7 @@ import numpy as np from arraycontext.context import ArrayContext from arraycontext.container import ( - ContainerT, ArrayOrContainerT, + ContainerT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container) @@ -93,7 +93,7 @@ def _map_array_container_impl( try: iterable = serialize_container(_ary) - except TypeError: + except NotAnArrayContainerError: return f(_ary) else: return deserialize_container(_ary, [ @@ -127,7 +127,7 @@ def _multimap_array_container_impl( try: iterable_template = serialize_container(template_ary) - except TypeError: + except NotAnArrayContainerError: return f(*_args) else: pass @@ -170,7 +170,7 @@ def _multimap_array_container_impl( # FIXME: this will serialize again once `rec` is called, which is # not great, but it doesn't seem like there's a good way to avoid it _ = serialize_container(arg) - except TypeError: + except NotAnArrayContainerError: pass else: container_indices.append(i) @@ -231,7 +231,7 @@ def map_array_container( """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return f(ary) else: return deserialize_container(ary, [ @@ -316,7 +316,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: raise ValueError( f"Non-array container type has no key: {type(ary).__name__}") else: @@ -338,7 +338,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except TypeError: + except NotAnArrayContainerError: return f(keys, _ary) else: return deserialize_container(_ary, [ @@ -367,7 +367,7 @@ def map_reduce_array_container( """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return map_func(ary) else: return reduce_func([ @@ -442,7 +442,7 @@ def rec_map_reduce_array_container( def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) - except TypeError: + except NotAnArrayContainerError: return map_func(_ary) else: return reduce_func([ @@ -501,7 +501,7 @@ def freeze( """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: if actx is None: raise TypeError( f"cannot freeze arrays of type {type(ary).__name__} " @@ -538,7 +538,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: """ try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: return actx.thaw(ary) else: return deserialize_container(ary, [ @@ -567,7 +567,7 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: try: iterable = serialize_container(subary) - except TypeError: + except NotAnArrayContainerError: if common_dtype is None: common_dtype = subary.dtype @@ -618,7 +618,7 @@ def unflatten( try: iterable = serialize_container(template_subary) - except TypeError: + except NotAnArrayContainerError: if (offset + template_subary.size) > ary.size: raise ValueError("'template' and 'ary' sizes do not match: " "'template' is too large") @@ -682,9 +682,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any: The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ def _from_numpy_with_check(subary: Any) -> Any: - if np.isscalar(subary): - return subary - elif isinstance(subary, np.ndarray): + if isinstance(subary, np.ndarray) or np.isscalar(subary): return actx.from_numpy(subary) else: raise TypeError(f"array is not an ndarray: '{type(subary).__name__}'") @@ -699,9 +697,7 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any: The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`. """ def _to_numpy_with_check(subary: Any) -> Any: - if np.isscalar(subary): - return subary - elif isinstance(subary, actx.array_types): + if isinstance(subary, actx.array_types) or np.isscalar(subary): return actx.to_numpy(subary) else: raise TypeError( @@ -734,7 +730,7 @@ def outer(a: Any, b: Any) -> Any: def treat_as_scalar(x: Any) -> bool: try: serialize_container(x) - except TypeError: + except NotAnArrayContainerError: return True else: return ( diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index fec1463b74fd74afaecc45f66e318e1ba8069073..022e0537b7acf07a23f915b2d0c97b0d96c056d1 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -24,7 +24,7 @@ THE SOFTWARE. import numpy as np -from arraycontext.container import serialize_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, multimapped_over_array_containers) from pytools import memoize_in @@ -258,7 +258,7 @@ class BaseFakeNumpyLinalgNamespace: try: iterable = serialize_container(ary) - except TypeError: + except NotAnArrayContainerError: pass else: return _reduce_norm(actx, [ diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 83c1b43d1c755b903f4e9ca4ddc8aa64c5a1f598..4e9d48beca3de0b505591300b2f17df1e1de31ac 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -33,7 +33,7 @@ import numpy as np from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace -from arraycontext.container import serialize_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, rec_multimap_array_container, @@ -252,7 +252,7 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): try: iterable = zip(serialize_container(x), serialize_container(y)) - except TypeError: + except NotAnArrayContainerError: if x.shape != y.shape: return false else: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 389ea07a6de8cadb927dd2bafb6ba41c6870e2a4..905fd0f873ba86597eed79ac0387ca9e8a494ac9 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -28,7 +28,7 @@ import numpy as np from arraycontext.fake_numpy import ( BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, ) -from arraycontext.container import serialize_container +from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, rec_multimap_array_container, @@ -186,7 +186,7 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): try: iterable = zip(serialize_container(x), serialize_container(y)) - except TypeError: + except NotAnArrayContainerError: if x.shape != y.shape: return false else: