diff --git a/arraycontext/context.py b/arraycontext/context.py index cc157a4f8ea1068d0044abb92840b0fb777d0fea..72bf8c7de0bf9ddd8740be3d67b00f952a999faf 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -80,6 +80,19 @@ The interface of an array context .. autoclass:: Array .. autoclass:: Scalar .. autoclass:: ArrayContext + +Internal typing helpers (do not import) +--------------------------------------- + +.. currentmodule:: arraycontext.context + +This is only here because the documentation tool wants it. + +.. class:: SelfType + +.. class:: ArrayT + + A type variable, with a lower bound of :class:`Array`. """ @@ -110,7 +123,7 @@ THE SOFTWARE. from abc import ABC, abstractmethod from typing import ( Any, Callable, Dict, Optional, Tuple, Union, - TYPE_CHECKING) + TYPE_CHECKING, TypeVar) import numpy as np from pytools import memoize_method @@ -129,6 +142,8 @@ try: except ImportError: from typing_extensions import Protocol # type: ignore[misc] +SelfType = TypeVar("SelfType") + class Array(Protocol): """A :class:`~typing.Protocol` for the array type supported by @@ -150,6 +165,9 @@ class Array(Protocol): ... +ArrayT = TypeVar("ArrayT", bound=Array) + + class Scalar(Protocol): """A :class:`~typing.Protocol` for the scalar type supported by :class:`ArrayContext`. @@ -322,7 +340,7 @@ class ArrayContext(ABC): @abstractmethod def tag(self, tags: ToTagSetConvertible, - array: Array) -> Array: + array: ArrayT) -> ArrayT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* with the *tags* applied. *array* itself is not modified. @@ -335,7 +353,7 @@ class ArrayContext(ABC): @abstractmethod def tag_axis(self, iaxis: int, tags: ToTagSetConvertible, - array: Array) -> Array: + array: ArrayT) -> ArrayT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* in which axis number *iaxis* has the *tags* applied. *array* itself is not modified. @@ -406,7 +424,7 @@ class ArrayContext(ABC): return self.tag(tagged, out_ary) @abstractmethod - def clone(self) -> "ArrayContext": + def clone(self: SelfType) -> SelfType: """If possible, return a version of *self* that is semantically equivalent (i.e. implements all array operations in the same way) but is a separate object. May return *self* if that is not possible.