From b75ba4f60c601664edc495cc3606fc9f7e47866b Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 5 Apr 2022 10:04:24 -0500
Subject: [PATCH] add typing to ArrayContext

---
 arraycontext/container/traversal.py           |  3 +
 arraycontext/context.py                       | 92 +++++++++++++------
 arraycontext/impl/pyopencl/__init__.py        |  8 +-
 .../impl/pyopencl/taggable_cl_array.py        |  6 +-
 arraycontext/impl/pytato/__init__.py          | 14 +--
 setup.py                                      |  2 +-
 6 files changed, 80 insertions(+), 45 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index de89a6b..cc3bbd5 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -833,6 +833,9 @@ def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
     """
     def _to_numpy_with_check(subary: Any) -> Any:
         if isinstance(subary, actx.array_types) or np.isscalar(subary):
+            # NOTE: these are allowed by np.isscalar, but not here
+            assert not isinstance(subary, (str, bytes))
+
             return actx.to_numpy(subary)
         else:
             raise TypeError(
diff --git a/arraycontext/context.py b/arraycontext/context.py
index d206a87..49e6acc 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -1,3 +1,5 @@
+# mypy: disallow-untyped-defs
+
 """
 .. _freeze-thaw:
 
@@ -105,12 +107,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Sequence, Union, Callable, Any, Tuple
-from abc import ABC, abstractmethod, abstractproperty
+from abc import ABC, abstractmethod
+from typing import (
+        Any, Callable, Dict, Optional, Tuple, Union,
+        TYPE_CHECKING)
 
 import numpy as np
 from pytools import memoize_method
-from pytools.tag import Tag
+from pytools.tag import ToTagSetConvertible
+
+if TYPE_CHECKING:
+    import loopy
 
 
 # {{{ typing
@@ -218,29 +225,35 @@ class ArrayContext(ABC):
 
     array_types: Tuple[type, ...] = ()
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.np = self._get_fake_numpy_namespace()
 
-    def _get_fake_numpy_namespace(self):
+    def _get_fake_numpy_namespace(self) -> Any:
         from .fake_numpy import BaseFakeNumpyNamespace
         return BaseFakeNumpyNamespace(self)
 
     @abstractmethod
-    def empty(self, shape, dtype):
+    def empty(self,
+              shape: Union[int, Tuple[int, ...]],
+              dtype: "np.dtype[Any]") -> Array:
         pass
 
     @abstractmethod
-    def zeros(self, shape, dtype):
+    def zeros(self,
+              shape: Union[int, Tuple[int, ...]],
+              dtype: "np.dtype[Any]") -> Array:
         pass
 
-    def empty_like(self, ary):
+    def empty_like(self, ary: Array) -> Array:
         return self.empty(shape=ary.shape, dtype=ary.dtype)
 
-    def zeros_like(self, ary):
+    def zeros_like(self, ary: Array) -> Array:
         return self.zeros(shape=ary.shape, dtype=ary.dtype)
 
     @abstractmethod
-    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
+    def from_numpy(self,
+                   array: Union["np.ndarray[Any, Any]", _ScalarLike]
+                   ) -> Union[Array, _ScalarLike]:
         r"""
         :returns: the :class:`numpy.ndarray` *array* converted to the
             array context's array type. The returned array will be
@@ -249,7 +262,9 @@ class ArrayContext(ABC):
         pass
 
     @abstractmethod
-    def to_numpy(self, array):
+    def to_numpy(self,
+                 array: Union[Array, _ScalarLike]
+                 ) -> Union["np.ndarray[Any, Any]", _ScalarLike]:
         r"""
         :returns: *array*, an array recognized by the context, converted
             to a :class:`numpy.ndarray`. *array* must be
@@ -257,7 +272,9 @@ class ArrayContext(ABC):
         """
         pass
 
-    def call_loopy(self, program, **kwargs):
+    def call_loopy(self,
+                   program: "loopy.TranslationUnit",
+                   **kwargs: Any) -> Dict[str, Array]:
         """Execute the :mod:`loopy` program *program* on the arguments
         *kwargs*.
 
@@ -270,7 +287,7 @@ class ArrayContext(ABC):
         """
 
     @abstractmethod
-    def freeze(self, array):
+    def freeze(self, array: Array) -> Array:
         """Return a version of the context-defined array *array* that is
         'frozen', i.e. suitable for long-term storage and reuse. Frozen arrays
         do not support arithmetic. For example, in the context of
@@ -286,7 +303,7 @@ class ArrayContext(ABC):
         """
 
     @abstractmethod
-    def thaw(self, array):
+    def thaw(self, array: Array) -> Array:
         """Take a 'frozen' array and return a new array representing the data in
         *array* that is able to perform arithmetic and other operations, using
         the execution resources of this context. In the context of
@@ -301,7 +318,9 @@ class ArrayContext(ABC):
         """
 
     @abstractmethod
-    def tag(self, tags: Union[Sequence[Tag], Tag], array):
+    def tag(self,
+            tags: ToTagSetConvertible,
+            array: Array) -> Array:
         """If the array type used by the array context is capable of capturing
         metadata, return a version of *array* with the *tags* applied. *array*
         itself is not modified.
@@ -310,7 +329,9 @@ class ArrayContext(ABC):
         """
 
     @abstractmethod
-    def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
+    def tag_axis(self,
+                 iaxis: int, tags: ToTagSetConvertible,
+                 array: Array) -> Array:
         """If the array type used by the array context is capable of capturing
         metadata, return a version of *array* in which axis number *iaxis* has
         the *tags* applied. *array* itself is not modified.
@@ -319,7 +340,9 @@ class ArrayContext(ABC):
         """
 
     @memoize_method
-    def _get_einsum_prg(self, spec, arg_names, tagged):
+    def _get_einsum_prg(self,
+                        spec: str, arg_names: Tuple[str, ...],
+                        tagged: ToTagSetConvertible) -> "loopy.TranslationUnit":
         import loopy as lp
         from .loopy import _DEFAULT_LOOPY_OPTIONS
         from loopy.version import MOST_RECENT_LANGUAGE_VERSION
@@ -345,7 +368,10 @@ class ArrayContext(ABC):
     # That's why einsum's interface here needs to be cluttered with
     # metadata, and that's why it can't live under .np.
     # [1] https://github.com/inducer/meshmode/issues/177
-    def einsum(self, spec, *args, arg_names=None, tagged=()):
+    def einsum(self,
+               spec: str, *args: Array,
+               arg_names: Optional[Tuple[str, ...]] = None,
+               tagged: ToTagSetConvertible = ()) -> Array:
         """Computes the result of Einstein summation following the
         convention in :func:`numpy.einsum`.
 
@@ -365,7 +391,7 @@ class ArrayContext(ABC):
         :return: the output of the einsum :mod:`loopy` program
         """
         if arg_names is None:
-            arg_names = tuple("arg%d" % i for i in range(len(args)))
+            arg_names = tuple([f"arg{i}" for i in range(len(args))])
 
         prg = self._get_einsum_prg(spec, arg_names, tagged)
         out_ary = self.call_loopy(
@@ -413,22 +439,28 @@ class ArrayContext(ABC):
         return f
 
     # undocumented for now
-    @abstractproperty
-    def permits_inplace_modification(self):
-        pass
+    @property
+    @abstractmethod
+    def permits_inplace_modification(self) -> bool:
+        """
+        *True* if the arrays allow in-place modifications.
+        """
 
     # undocumented for now
-    @abstractproperty
-    def supports_nonscalar_broadcasting(self):
-        pass
+    @property
+    @abstractmethod
+    def supports_nonscalar_broadcasting(self) -> bool:
+        """
+        *True* if the arrays support non-scalar broadcasting.
+        """
 
-    @abstractproperty
-    def permits_advanced_indexing(self):
+    # undocumented for now
+    @property
+    @abstractmethod
+    def permits_advanced_indexing(self) -> bool:
         """
-        *True* only if the arrays support :mod:`numpy`'s advanced indexing
-        semantics.
+        *True* if the arrays support :mod:`numpy`'s advanced indexing semantics.
         """
-        pass
 
 # }}}
 
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 4e6cc38..ab8f568 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -29,11 +29,11 @@ THE SOFTWARE.
 """
 
 from warnings import warn
-from typing import Dict, List, Sequence, Optional, Union, TYPE_CHECKING
+from typing import Dict, List, Optional, Union, TYPE_CHECKING
 
 import numpy as np
 
-from pytools.tag import Tag
+from pytools.tag import ToTagSetConvertible
 
 from arraycontext.context import ArrayContext, _ScalarLike
 from arraycontext.container.traversal import rec_map_array_container
@@ -301,7 +301,7 @@ class PyOpenCLArrayContext(ArrayContext):
 
         return t_unit
 
-    def tag(self, tags: Union[Sequence[Tag], Tag], array):
+    def tag(self, tags: ToTagSetConvertible, array):
         import pyopencl.array as cl_array
         from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
                                                                   to_tagged_cl_array)
@@ -317,7 +317,7 @@ class PyOpenCLArrayContext(ArrayContext):
 
         return rec_map_array_container(_rec_tagged, array)
 
-    def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
+    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
         import pyopencl.array as cl_array
         from arraycontext.impl.pyopencl.taggable_cl_array import (TaggableCLArray,
                                                                   to_tagged_cl_array)
diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py
index d343447..aa7d12d 100644
--- a/arraycontext/impl/pyopencl/taggable_cl_array.py
+++ b/arraycontext/impl/pyopencl/taggable_cl_array.py
@@ -7,7 +7,7 @@
 
 import pyopencl.array as cla
 from typing import Any, Dict, FrozenSet, Optional, Tuple
-from pytools.tag import Taggable, Tag, TagsType, TagOrIterableType
+from pytools.tag import Taggable, Tag, ToTagSetConvertible
 from dataclasses import dataclass
 from pytools import memoize
 
@@ -93,12 +93,12 @@ class TaggableCLArray(cla.Array, Taggable):
         return type(self)(None, tags=self.tags, axes=self.axes,
                           **_unwrap_cl_array(ary))
 
-    def _with_new_tags(self, tags: TagsType) -> "TaggableCLArray":
+    def _with_new_tags(self, tags: FrozenSet[Tag]) -> "TaggableCLArray":
         return type(self)(None, tags=tags, axes=self.axes,
                           **_unwrap_cl_array(self))
 
     def with_tagged_axis(self, iaxis: int,
-                         tags: TagOrIterableType) -> "TaggableCLArray":
+                         tags: ToTagSetConvertible) -> "TaggableCLArray":
         """
         Returns a copy of *self* with *iaxis*-th axis tagged with *tags*.
         """
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 730133e..a895eb4 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -44,8 +44,8 @@ THE SOFTWARE.
 from arraycontext.context import ArrayContext, _ScalarLike
 from arraycontext.container.traversal import rec_map_array_container
 import numpy as np
-from typing import Any, Callable, Union, Sequence, TYPE_CHECKING
-from pytools.tag import Tag
+from typing import Any, Callable, Union, TYPE_CHECKING
+from pytools.tag import ToTagSetConvertible
 
 if TYPE_CHECKING:
     import pytato
@@ -262,14 +262,14 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
 
         return dag
 
-    def tag(self, tags: Union[Sequence[Tag], Tag], array):
+    def tag(self, tags: ToTagSetConvertible, array):
         return rec_map_array_container(lambda x: x.tagged(tags),
                                        array)
 
-    def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array):
-        return rec_map_array_container(lambda x: x.with_tagged_axis(iaxis,
-                                                                     tags),
-                                        array)
+    def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
+        return rec_map_array_container(
+            lambda x: x.with_tagged_axis(iaxis, tags),
+            array)
 
     def einsum(self, spec, *args, arg_names=None, tagged=()):
         import pyopencl.array as cla
diff --git a/setup.py b/setup.py
index 06e898d..2bc066e 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ def main():
             "numpy",
 
             # https://github.com/inducer/arraycontext/pull/147
-            "pytools>=2022.1.1",
+            "pytools>=2022.1.3",
 
             "pytest>=2.3",
             "loopy>=2019.1",
-- 
GitLab