diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 22977658c47c4a56c12df59f020567e1ee71d3ff..6f7308dbe1221a755aa05667405f66dbc45444b0 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -29,7 +29,7 @@ THE SOFTWARE. """ import sys -from .context import ArrayContext +from .context import ArrayContext, DeviceArray, DeviceScalar from .transform_metadata import (CommonSubexpressionTag, ElementwiseMapKernelTag) @@ -74,7 +74,7 @@ from .loopy import make_loopy_program __all__ = ( - "ArrayContext", + "ArrayContext", "DeviceScalar", "DeviceArray", "CommonSubexpressionTag", "ElementwiseMapKernelTag", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 3eb3664f4c9be355424d2fb4ef0963d9d4b91337..aa6de375076384afbc8f16e04f195327e5db824a 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -67,7 +67,7 @@ from functools import update_wrapper, partial, singledispatch import numpy as np -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, DeviceArray from arraycontext.container import ( ContainerT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container) @@ -355,7 +355,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], def map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], - ary: ArrayOrContainerT) -> Any: + ary: ArrayOrContainerT) -> "DeviceArray": """Perform a map-reduce over array containers. :param reduce_func: callable used to reduce over the components of *ary* @@ -378,7 +378,7 @@ def map_reduce_array_container( def multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], - *args: Any) -> Any: + *args: Any) -> "DeviceArray": r"""Perform a map-reduce over multiple array containers. :param reduce_func: callable used to reduce over the components of any @@ -401,7 +401,7 @@ def multimap_reduce_array_container( def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], - ary: ArrayOrContainerT) -> Any: + ary: ArrayOrContainerT) -> "DeviceArray": """Perform a map-reduce over array containers recursively. :param reduce_func: callable used to reduce over the components of *ary* @@ -455,7 +455,7 @@ def rec_map_reduce_array_container( def rec_multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], - *args: Any) -> Any: + *args: Any) -> "DeviceArray": r"""Perform a map-reduce over multiple array containers recursively. :param reduce_func: callable used to reduce over the components of any diff --git a/arraycontext/context.py b/arraycontext/context.py index 379c8e92d8aefb1b0c5cf23b8b485314029b2bef..fa705136a8ceb7d224a75b1ddee33e6cc5408682 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -74,6 +74,19 @@ The interface of an array context --------------------------------- .. currentmodule:: arraycontext + +.. class:: DeviceArray + + A (type alias for an) array type supported by the :class:`ArrayContext` + meant to aid in typing annotations. For a explicit list of supported types + see :attr:`ArrayContext.array_types`. + +.. class:: DeviceScalar + + A (type alias for a) scalar type supported by the :class:`ArrayContext` + meant to aid in typing annotations, e.g. for reductions. In :mod:`numpy` + terminology, this is just an array with a shape of ``()``. + .. autoclass:: ArrayContext """ @@ -110,6 +123,10 @@ from pytools import memoize_method from pytools.tag import Tag +DeviceArray = Any +DeviceScalar = Any + + # {{{ ArrayContext class ArrayContext(ABC): diff --git a/doc/conf.py b/doc/conf.py index bee0e10b98630bf89addaabe7931fd5eeaf86478..29f026ea52d931b79d1b070ba96aac85dd40931c 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -14,6 +14,11 @@ exec(compile(open("../arraycontext/version.py").read(), "../arraycontext/version version = ".".join(str(x) for x in ver_dic["VERSION"]) release = ver_dic["VERSION_TEXT"] +autodoc_type_aliases = { + "DeviceScalar": "arraycontext.DeviceScalar", + "DeviceArray": "arraycontext.DeviceArray", + } + intersphinx_mapping = { "https://docs.python.org/3/": None, "https://numpy.org/doc/stable/": None,