From 1bce2d26e7aff51c3071a086cbf60c3d6594c8a5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 27 Nov 2024 13:04:34 -0600 Subject: [PATCH] Introduce ArithArrayContainer --- arraycontext/__init__.py | 10 ++++++++ arraycontext/container/__init__.py | 27 +++++++++++++++++++- arraycontext/context.py | 40 ++++++++++++++---------------- doc/conf.py | 5 ++++ pyproject.toml | 2 ++ 5 files changed, 62 insertions(+), 22 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 1c2ae45..674a229 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -30,6 +30,7 @@ THE SOFTWARE. """ from .container import ( + ArithArrayContainer, ArrayContainer, ArrayContainerT, NotAnArrayContainerError, @@ -73,6 +74,10 @@ from .container.traversal import ( from .context import ( Array, ArrayContext, + ArrayOrArithContainer, + ArrayOrArithContainerOrScalar, + ArrayOrArithContainerOrScalarT, + ArrayOrArithContainerT, ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, @@ -96,10 +101,15 @@ from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag __all__ = ( + "ArithArrayContainer", "Array", "ArrayContainer", "ArrayContainerT", "ArrayContext", + "ArrayOrArithContainer", + "ArrayOrArithContainerOrScalar", + "ArrayOrArithContainerOrScalarT", + "ArrayOrArithContainerT", "ArrayOrContainer", "ArrayOrContainerOrScalar", "ArrayOrContainerOrScalarT", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 75eee2a..afe4a40 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -4,6 +4,7 @@ .. currentmodule:: arraycontext .. autoclass:: ArrayContainer +.. autoclass:: ArithArrayContainer .. class:: ArrayContainerT A type variable with a lower bound of :class:`ArrayContainer`. @@ -87,8 +88,9 @@ from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar # what 'np' is. import numpy import numpy as np +from typing_extensions import Self -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, ArrayOrScalar if TYPE_CHECKING: @@ -145,6 +147,29 @@ class ArrayContainer(Protocol): # that are container-typed. +class ArithArrayContainer(ArrayContainer, Protocol): + """ + A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic. + """ + + # This is loose and permissive, assuming that any array can be added + # to any container. The alternative would be to plaster type-ignores + # on all those uses. Achieving typing precision on what broadcasting is + # allowable seems like a huge endeavor and is likely not feasible without + # a mypy plugin. Maybe some day? -AK, November 2024 + + def __neg__(self) -> Self: ... + def __abs__(self) -> Self: ... + def __add__(self, other: ArrayOrScalar | Self) -> Self: ... + def __radd__(self, other: ArrayOrScalar | Self) -> Self: ... + def __sub__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ... + def __mul__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ... + def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ... + + ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) diff --git a/arraycontext/context.py b/arraycontext/context.py index 6027797..0d0595c 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -87,37 +87,30 @@ Types and Type Variables for Arrays and Containers .. autoclass:: Array -.. class:: ArrayT +.. autodata:: ArrayT A type variable with a lower bound of :class:`Array`. -.. class:: ScalarLike +.. autodata:: ScalarLike A type annotation for scalar types commonly usable with arrays. See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`. -.. class:: ArrayOrContainer +.. autodata:: ArrayOrContainer -.. class:: ArrayOrContainerT +.. autodata:: ArrayOrContainerT A type variable with a lower bound of :class:`ArrayOrContainer`. -.. class:: ArrayOrContainerOrScalar +.. autodata:: ArrayOrContainerOrScalar -.. class:: ArrayOrContainerOrScalarT +.. autodata:: ArrayOrContainerOrScalarT A type variable with a lower bound of :class:`ArrayOrContainerOrScalar`. -Internal typing helpers (do not import) ---------------------------------------- - .. currentmodule:: arraycontext.context -This is only here because the documentation tool wants it. - -.. class:: SelfType - Canonical locations for type annotations ---------------------------------------- @@ -176,16 +169,11 @@ from pytools.tag import ToTagSetConvertible if TYPE_CHECKING: import loopy - from arraycontext.container import ArrayContainer + from arraycontext.container import ArithArrayContainer, ArrayContainer # {{{ typing -ScalarLike = int | float | complex | np.generic - -SelfType = TypeVar("SelfType") - - class Array(Protocol): """A :class:`~typing.Protocol` for the array type supported by :class:`ArrayContext`. @@ -236,16 +224,26 @@ class Array(Protocol): # deprecated, use ScalarLike instead ScalarLike: TypeAlias = int | float | complex | np.generic Scalar = ScalarLike - +ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike) ArrayT = TypeVar("ArrayT", bound=Array) ArrayOrScalar: TypeAlias = "Array | ScalarLike" ArrayOrContainer: TypeAlias = "Array | ArrayContainer" +ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer" ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) +ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer) ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike" +ArrayOrArithContainerOrScalar: TypeAlias = "Array | ArithArrayContainer | ScalarLike" ArrayOrContainerOrScalarT = TypeVar( "ArrayOrContainerOrScalarT", bound=ArrayOrContainerOrScalar) +ArrayOrArithContainerOrScalarT = TypeVar( + "ArrayOrArithContainerOrScalarT", + bound=ArrayOrContainerOrScalar) + + +ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike") + NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike] @@ -494,7 +492,7 @@ class ArrayContext(ABC): return self.tag(tagged, out_ary) @abstractmethod - def clone(self: SelfType) -> SelfType: + def clone(self) -> Self: """If possible, return a version of *self* that is semantically equivalent (i.e. implements all array operations in the same way) but is a separate object. May return *self* if that is not possible. diff --git a/doc/conf.py b/doc/conf.py index 0ba4930..0042ae5 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -38,3 +38,8 @@ import sys sys._BUILDING_SPHINX_DOCS = True + + +nitpick_ignore_regex = [ + ["py:class", r"arraycontext\.context\.ContainerOrScalarT"], + ] diff --git a/pyproject.toml b/pyproject.toml index d715981..2e51586 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "immutabledict>=4.1", "numpy", "pytools>=2024.1.3", + # for Self + "typing_extensions>=4", ] [project.optional-dependencies] -- GitLab