From 456d8933a2fa265cee30d5aaf34f903ddc1063eb Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Tue, 5 Apr 2022 10:04:24 -0500 Subject: [PATCH] allow dataclass containers with only DeviceArrays --- arraycontext/__init__.py | 26 ++++++----- arraycontext/container/dataclass.py | 3 +- arraycontext/container/traversal.py | 10 ++--- arraycontext/context.py | 68 +++++++++++++++++++++++------ doc/conf.py | 5 --- setup.py | 1 + test/test_utils.py | 13 ++++++ 7 files changed, 90 insertions(+), 36 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 8206fb8..e8f34d4 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -29,7 +29,7 @@ THE SOFTWARE. """ import sys -from .context import ArrayContext, DeviceArray, DeviceScalar +from .context import ArrayContext, Array, Scalar from .transform_metadata import (CommonSubexpressionTag, ElementwiseMapKernelTag) @@ -75,7 +75,7 @@ from .loopy import make_loopy_program __all__ = ( - "ArrayContext", "DeviceScalar", "DeviceArray", + "ArrayContext", "Scalar", "Array", "CommonSubexpressionTag", "ElementwiseMapKernelTag", @@ -125,24 +125,26 @@ def _deprecated_acf(): _depr_name_to_replacement_and_obj = { - "get_container_context": ("get_container_context_opt", - get_container_context_opt), - "FirstAxisIsElementsTag": - ("meshmode.transform_metadata.FirstAxisIsElementsTag", - _FirstAxisIsElementsTag), - "_acf": - ("<no replacement yet>", _deprecated_acf), + "get_container_context": ( + "get_container_context_opt", + get_container_context_opt, 2022), + "FirstAxisIsElementsTag": ( + "meshmode.transform_metadata.FirstAxisIsElementsTag", + _FirstAxisIsElementsTag, 2022), + "_acf": ("<no replacement yet>", _deprecated_acf, 2022), + "DeviceArray": ("Array", Array, 2023), + "DeviceScalar": ("Scalar", Scalar, 2023), } if sys.version_info >= (3, 7): def __getattr__(name): replacement_and_obj = _depr_name_to_replacement_and_obj.get(name, None) if replacement_and_obj is not None: - replacement, obj = replacement_and_obj + replacement, obj, year = replacement_and_obj from warnings import warn warn(f"'arraycontext.{name}' is deprecated. " f"Use '{replacement}' instead. " - f"'arraycontext.{name}' will continue to work until 2022.", + f"'arraycontext.{name}' will continue to work until {year}.", DeprecationWarning, stacklevel=2) return obj else: @@ -151,6 +153,8 @@ else: FirstAxisIsElementsTag = _FirstAxisIsElementsTag _acf = _deprecated_acf get_container_context = get_container_context_opt + DeviceArray = Array + DeviceScalar = Scalar # }}} diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 203246c..150d1d6 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -70,7 +70,8 @@ def dataclass_array_container(cls: type) -> type: f"field '{f.name}' not an instance of 'type': " f"'{f.type!r}'") - return is_array_container_type(f.type) + from arraycontext import Array + return f.type is Array or is_array_container_type(f.type) from pytools import partition array_fields, non_array_fields = partition(is_array_field, fields(cls)) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 23cec03..de89a6b 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -68,7 +68,7 @@ from functools import update_wrapper, partial, singledispatch import numpy as np -from arraycontext.context import ArrayContext, DeviceArray, _ScalarLike +from arraycontext.context import ArrayContext, Array, _ScalarLike from arraycontext.container import ( ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container) @@ -384,7 +384,7 @@ def rec_keyed_map_array_container( def map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], - ary: ArrayOrContainerT) -> "DeviceArray": + ary: ArrayOrContainerT) -> "Array": """Perform a map-reduce over array containers. :param reduce_func: callable used to reduce over the components of *ary* @@ -407,7 +407,7 @@ def map_reduce_array_container( def multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], - *args: Any) -> "DeviceArray": + *args: Any) -> "Array": r"""Perform a map-reduce over multiple array containers. :param reduce_func: callable used to reduce over the components of any @@ -431,7 +431,7 @@ def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], ary: ArrayOrContainerT, - leaf_class: Optional[type] = None) -> "DeviceArray": + leaf_class: Optional[type] = None) -> "Array": """Perform a map-reduce over array containers recursively. :param reduce_func: callable used to reduce over the components of *ary* @@ -489,7 +489,7 @@ def rec_multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], *args: Any, - leaf_class: Optional[type] = None) -> "DeviceArray": + leaf_class: Optional[type] = None) -> "Array": r"""Perform a map-reduce over multiple array containers recursively. :param reduce_func: callable used to reduce over the components of any diff --git a/arraycontext/context.py b/arraycontext/context.py index 6c13a33..d206a87 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -75,18 +75,8 @@ The interface of an array context .. currentmodule:: arraycontext -.. class:: DeviceArray - - A (type alias for an) array type supported by the :class:`ArrayContext` - meant to aid in typing annotations. For a explicit list of supported types - see :attr:`ArrayContext.array_types`. - -.. class:: DeviceScalar - - A (type alias for a) scalar type supported by the :class:`ArrayContext` - meant to aid in typing annotations, e.g. for reductions. In :mod:`numpy` - terminology, this is just an array with a shape of ``()``. - +.. autoclass:: Array +.. autoclass:: Scalar .. autoclass:: ArrayContext """ @@ -123,10 +113,60 @@ from pytools import memoize_method from pytools.tag import Tag -DeviceArray = Any -DeviceScalar = Any +# {{{ typing + _ScalarLike = Union[int, float, complex, np.generic] +try: + from typing import Protocol +except ImportError: + from typing_extensions import Protocol # type: ignore[misc] + + +class Array(Protocol): + """A :class:`~typing.Protocol` for the array type supported by + :class:`ArrayContext`. + + This is meant to aid in typing annotations. For a explicit list of + supported types see :attr:`ArrayContext.array_types`. + + .. attribute:: shape + .. attribute:: dtype + """ + + @property + def shape(self) -> Tuple[int, ...]: + ... + + @property + def dtype(self) -> "np.dtype[Any]": + ... + + +class Scalar(Protocol): + """A :class:`~typing.Protocol` for the scalar type supported by + :class:`ArrayContext`. + + In :mod:`numpy` terminology, this is just an array with a shape of ``()``. + + This is meant to aid in typing annotations. For a explicit list of + supported types see :attr:`ArrayContext.array_types`. + + .. attribute:: shape + .. attribute:: dtype + """ + + @property + def shape(self) -> Tuple[()]: + ... + + @property + def dtype(self) -> "np.dtype[Any]": + ... + + +# }}} + # {{{ ArrayContext diff --git a/doc/conf.py b/doc/conf.py index 29f026e..bee0e10 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -14,11 +14,6 @@ exec(compile(open("../arraycontext/version.py").read(), "../arraycontext/version version = ".".join(str(x) for x in ver_dic["VERSION"]) release = ver_dic["VERSION_TEXT"] -autodoc_type_aliases = { - "DeviceScalar": "arraycontext.DeviceScalar", - "DeviceArray": "arraycontext.DeviceArray", - } - intersphinx_mapping = { "https://docs.python.org/3/": None, "https://numpy.org/doc/stable/": None, diff --git a/setup.py b/setup.py index 8b0d677..06e898d 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def main(): "pytest>=2.3", "loopy>=2019.1", "dataclasses; python_version<'3.7'", + "typing_extensions; python_version<'3.8'", "types-dataclasses", ], package_data={"arraycontext": ["py.typed"]}, diff --git a/test/test_utils.py b/test/test_utils.py index 08b6c3a..ac3127f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -90,6 +90,19 @@ def test_dataclass_array_container(): # }}} + # {{{ device arrays + + from arraycontext import Array + + @dataclass + class ArrayContainerWithArray: + x: Array + y: Array + + dataclass_array_container(ArrayContainerWithArray) + + # }}} + # }}} -- GitLab