diff --git a/arraycontext/context.py b/arraycontext/context.py index f6dc70bfe65144c625f184558de6dedca4bc4847..602779700beb276c79991df367296aaf9ed319b5 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -167,6 +167,7 @@ from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union from warnings import warn import numpy as np +from typing_extensions import Self from pytools import memoize_method from pytools.tag import ToTagSetConvertible @@ -196,6 +197,8 @@ class Array(Protocol): .. attribute:: size .. attribute:: dtype .. attribute:: __getitem__ + + In addition, arrays are expected to support basic arithmetic. """ @property @@ -217,8 +220,21 @@ class Array(Protocol): def __getitem__(self, index: Any) -> Array: ... + # some basic arithmetic that's supposed to work + def __neg__(self) -> Self: ... + def __abs__(self) -> Self: ... + def __add__(self, other: Self | ScalarLike) -> Self: ... + def __radd__(self, other: Self | ScalarLike) -> Self: ... + def __sub__(self, other: Self | ScalarLike) -> Self: ... + def __rsub__(self, other: Self | ScalarLike) -> Self: ... + def __mul__(self, other: Self | ScalarLike) -> Self: ... + def __rmul__(self, other: Self | ScalarLike) -> Self: ... + def __truediv__(self, other: Self | ScalarLike) -> Self: ... + def __rtruediv__(self, other: Self | ScalarLike) -> Self: ... + # deprecated, use ScalarLike instead +ScalarLike: TypeAlias = int | float | complex | np.generic Scalar = ScalarLike