Skip to content
Snippets Groups Projects
Commit 1db13f2e authored by Alexandru Fikl's avatar Alexandru Fikl Committed by Andreas Klöckner
Browse files

add DeviceArray type alias for typing

parent a2620be2
No related branches found
No related tags found
No related merge requests found
Pipeline #226119 passed
......@@ -29,7 +29,7 @@ THE SOFTWARE.
"""
import sys
from .context import ArrayContext
from .context import ArrayContext, DeviceArray, DeviceScalar
from .transform_metadata import (CommonSubexpressionTag,
ElementwiseMapKernelTag)
......@@ -74,7 +74,7 @@ from .loopy import make_loopy_program
__all__ = (
"ArrayContext",
"ArrayContext", "DeviceScalar", "DeviceArray",
"CommonSubexpressionTag",
"ElementwiseMapKernelTag",
......
......@@ -67,7 +67,7 @@ from functools import update_wrapper, partial, singledispatch
import numpy as np
from arraycontext.context import ArrayContext
from arraycontext.context import ArrayContext, DeviceArray
from arraycontext.container import (
ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
serialize_container, deserialize_container)
......@@ -355,7 +355,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
def map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any],
ary: ArrayOrContainerT) -> Any:
ary: ArrayOrContainerT) -> "DeviceArray":
"""Perform a map-reduce over array containers.
:param reduce_func: callable used to reduce over the components of *ary*
......@@ -378,7 +378,7 @@ def map_reduce_array_container(
def multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any],
*args: Any) -> Any:
*args: Any) -> "DeviceArray":
r"""Perform a map-reduce over multiple array containers.
:param reduce_func: callable used to reduce over the components of any
......@@ -401,7 +401,7 @@ def multimap_reduce_array_container(
def rec_map_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[[Any], Any],
ary: ArrayOrContainerT) -> Any:
ary: ArrayOrContainerT) -> "DeviceArray":
"""Perform a map-reduce over array containers recursively.
:param reduce_func: callable used to reduce over the components of *ary*
......@@ -455,7 +455,7 @@ def rec_map_reduce_array_container(
def rec_multimap_reduce_array_container(
reduce_func: Callable[[Iterable[Any]], Any],
map_func: Callable[..., Any],
*args: Any) -> Any:
*args: Any) -> "DeviceArray":
r"""Perform a map-reduce over multiple array containers recursively.
:param reduce_func: callable used to reduce over the components of any
......
......@@ -74,6 +74,19 @@ The interface of an array context
---------------------------------
.. currentmodule:: arraycontext
.. class:: DeviceArray
A (type alias for an) array type supported by the :class:`ArrayContext`
meant to aid in typing annotations. For a explicit list of supported types
see :attr:`ArrayContext.array_types`.
.. class:: DeviceScalar
A (type alias for a) scalar type supported by the :class:`ArrayContext`
meant to aid in typing annotations, e.g. for reductions. In :mod:`numpy`
terminology, this is just an array with a shape of ``()``.
.. autoclass:: ArrayContext
"""
......@@ -110,6 +123,10 @@ from pytools import memoize_method
from pytools.tag import Tag
DeviceArray = Any
DeviceScalar = Any
# {{{ ArrayContext
class ArrayContext(ABC):
......
......@@ -14,6 +14,11 @@ exec(compile(open("../arraycontext/version.py").read(), "../arraycontext/version
version = ".".join(str(x) for x in ver_dic["VERSION"])
release = ver_dic["VERSION_TEXT"]
autodoc_type_aliases = {
"DeviceScalar": "arraycontext.DeviceScalar",
"DeviceArray": "arraycontext.DeviceArray",
}
intersphinx_mapping = {
"https://docs.python.org/3/": None,
"https://numpy.org/doc/stable/": None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment