From 1bce2d26e7aff51c3071a086cbf60c3d6594c8a5 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 27 Nov 2024 13:04:34 -0600
Subject: [PATCH] Introduce ArithArrayContainer

---
 arraycontext/__init__.py           | 10 ++++++++
 arraycontext/container/__init__.py | 27 +++++++++++++++++++-
 arraycontext/context.py            | 40 ++++++++++++++----------------
 doc/conf.py                        |  5 ++++
 pyproject.toml                     |  2 ++
 5 files changed, 62 insertions(+), 22 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 1c2ae45..674a229 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -30,6 +30,7 @@ THE SOFTWARE.
 """
 
 from .container import (
+    ArithArrayContainer,
     ArrayContainer,
     ArrayContainerT,
     NotAnArrayContainerError,
@@ -73,6 +74,10 @@ from .container.traversal import (
 from .context import (
     Array,
     ArrayContext,
+    ArrayOrArithContainer,
+    ArrayOrArithContainerOrScalar,
+    ArrayOrArithContainerOrScalarT,
+    ArrayOrArithContainerT,
     ArrayOrContainer,
     ArrayOrContainerOrScalar,
     ArrayOrContainerOrScalarT,
@@ -96,10 +101,15 @@ from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
 
 
 __all__ = (
+    "ArithArrayContainer",
     "Array",
     "ArrayContainer",
     "ArrayContainerT",
     "ArrayContext",
+    "ArrayOrArithContainer",
+    "ArrayOrArithContainerOrScalar",
+    "ArrayOrArithContainerOrScalarT",
+    "ArrayOrArithContainerT",
     "ArrayOrContainer",
     "ArrayOrContainerOrScalar",
     "ArrayOrContainerOrScalarT",
diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 75eee2a..afe4a40 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -4,6 +4,7 @@
 .. currentmodule:: arraycontext
 
 .. autoclass:: ArrayContainer
+.. autoclass:: ArithArrayContainer
 .. class:: ArrayContainerT
 
     A type variable with a lower bound of :class:`ArrayContainer`.
@@ -87,8 +88,9 @@ from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
 # what 'np' is.
 import numpy
 import numpy as np
+from typing_extensions import Self
 
-from arraycontext.context import ArrayContext
+from arraycontext.context import ArrayContext, ArrayOrScalar
 
 
 if TYPE_CHECKING:
@@ -145,6 +147,29 @@ class ArrayContainer(Protocol):
     # that are container-typed.
 
 
+class ArithArrayContainer(ArrayContainer, Protocol):
+    """
+    A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
+    """
+
+    # This is loose and permissive, assuming that any array can be added
+    # to any container. The alternative would be to plaster type-ignores
+    # on all those uses. Achieving typing precision on what broadcasting is
+    # allowable seems like a huge endeavor and is likely not feasible without
+    # a mypy plugin. Maybe some day? -AK, November 2024
+
+    def __neg__(self) -> Self: ...
+    def __abs__(self) -> Self: ...
+    def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
+    def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...
+
+
 ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
 
 
diff --git a/arraycontext/context.py b/arraycontext/context.py
index 6027797..0d0595c 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -87,37 +87,30 @@ Types and Type Variables for Arrays and Containers
 
 .. autoclass:: Array
 
-.. class:: ArrayT
+.. autodata:: ArrayT
 
     A type variable with a lower bound of :class:`Array`.
 
-.. class:: ScalarLike
+.. autodata:: ScalarLike
 
     A type annotation for scalar types commonly usable with arrays.
 
 See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`.
 
-.. class:: ArrayOrContainer
+.. autodata:: ArrayOrContainer
 
-.. class:: ArrayOrContainerT
+.. autodata:: ArrayOrContainerT
 
     A type variable with a lower bound of :class:`ArrayOrContainer`.
 
-.. class:: ArrayOrContainerOrScalar
+.. autodata:: ArrayOrContainerOrScalar
 
-.. class:: ArrayOrContainerOrScalarT
+.. autodata:: ArrayOrContainerOrScalarT
 
     A type variable with a lower bound of :class:`ArrayOrContainerOrScalar`.
 
-Internal typing helpers (do not import)
----------------------------------------
-
 .. currentmodule:: arraycontext.context
 
-This is only here because the documentation tool wants it.
-
-.. class:: SelfType
-
 Canonical locations for type annotations
 ----------------------------------------
 
@@ -176,16 +169,11 @@ from pytools.tag import ToTagSetConvertible
 if TYPE_CHECKING:
     import loopy
 
-    from arraycontext.container import ArrayContainer
+    from arraycontext.container import ArithArrayContainer, ArrayContainer
 
 
 # {{{ typing
 
-ScalarLike = int | float | complex | np.generic
-
-SelfType = TypeVar("SelfType")
-
-
 class Array(Protocol):
     """A :class:`~typing.Protocol` for the array type supported by
     :class:`ArrayContext`.
@@ -236,16 +224,26 @@ class Array(Protocol):
 # deprecated, use ScalarLike instead
 ScalarLike: TypeAlias = int | float | complex | np.generic
 Scalar = ScalarLike
-
+ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)
 
 ArrayT = TypeVar("ArrayT", bound=Array)
 ArrayOrScalar: TypeAlias = "Array | ScalarLike"
 ArrayOrContainer: TypeAlias = "Array | ArrayContainer"
+ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer"
 ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer)
+ArrayOrArithContainerT = TypeVar("ArrayOrArithContainerT", bound=ArrayOrArithContainer)
 ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike"
+ArrayOrArithContainerOrScalar: TypeAlias = "Array | ArithArrayContainer | ScalarLike"
 ArrayOrContainerOrScalarT = TypeVar(
         "ArrayOrContainerOrScalarT",
         bound=ArrayOrContainerOrScalar)
+ArrayOrArithContainerOrScalarT = TypeVar(
+        "ArrayOrArithContainerOrScalarT",
+        bound=ArrayOrContainerOrScalar)
+
+
+ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike")
+
 
 NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike]
 
@@ -494,7 +492,7 @@ class ArrayContext(ABC):
         return self.tag(tagged, out_ary)
 
     @abstractmethod
-    def clone(self: SelfType) -> SelfType:
+    def clone(self) -> Self:
         """If possible, return a version of *self* that is semantically
         equivalent (i.e. implements all array operations in the same way)
         but is a separate object. May return *self* if that is not possible.
diff --git a/doc/conf.py b/doc/conf.py
index 0ba4930..0042ae5 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -38,3 +38,8 @@ import sys
 
 
 sys._BUILDING_SPHINX_DOCS = True
+
+
+nitpick_ignore_regex = [
+    ["py:class", r"arraycontext\.context\.ContainerOrScalarT"],
+    ]
diff --git a/pyproject.toml b/pyproject.toml
index d715981..2e51586 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,8 @@ dependencies = [
     "immutabledict>=4.1",
     "numpy",
     "pytools>=2024.1.3",
+    # for Self
+    "typing_extensions>=4",
 ]
 
 [project.optional-dependencies]
-- 
GitLab