From 3e211cc3c73cfcbe107fb93daf0fc621228cb2d8 Mon Sep 17 00:00:00 2001
From: Alex Fikl <alexfikl@gmail.com>
Date: Mon, 7 Jun 2021 15:12:14 -0500
Subject: [PATCH] Enable mypy (#18)

* enable mypy in setup.cfg

* add some more type annotations

* add py.typed

* add mypy to ci

* fix some type annotations
---
 .github/workflows/ci.yml             | 17 ++++++++
 .gitlab-ci.yml                       | 12 +++++
 arraycontext/container/__init__.py   | 23 +++++++---
 arraycontext/container/arithmetic.py | 49 ++++++++++++++-------
 arraycontext/container/dataclass.py  |  7 +--
 arraycontext/container/traversal.py  | 65 ++++++++++++++++++----------
 arraycontext/py.typed                |  0
 doc/Makefile                         |  2 +-
 setup.cfg                            | 24 +++++++++-
 setup.py                             |  1 +
 10 files changed, 150 insertions(+), 50 deletions(-)
 create mode 100644 arraycontext/py.typed

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 76e8aa8..c9bfee5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -34,6 +34,23 @@ jobs:
                 curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh
                 . ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" examples/*.py test/test_*.py
 
+    mypy:
+        name: Mypy
+        runs-on: ubuntu-latest
+        steps:
+        -   uses: actions/checkout@v2
+        -
+            uses: actions/setup-python@v1
+            with:
+                python-version: '3.x'
+        -   name: "Main Script"
+            run: |
+                curl -L -O https://tiker.net/ci-support-v0
+                . ./ci-support-v0
+                build_py_project_in_conda_env
+                python -m pip install mypy
+                python -m mypy "$(basename $GITHUB_REPOSITORY)" test
+
     pytest3:
         name: Pytest Conda Py3
         runs-on: ubuntu-latest
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index d93ad5e..e596fd7 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -89,3 +89,15 @@ Pylint:
   - python3
   except:
   - tags
+
+Mypy:
+  script: |
+    curl -L -O https://tiker.net/ci-support-v0
+    . ./ci-support-v0
+    build_py_project_in_venv
+    python -m pip install mypy
+    python -m mypy "$CI_PROJECT_NAME" test
+  tags:
+  - python3
+  except:
+  - tags
diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index a3398f3..397af49 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -1,6 +1,12 @@
+# mypy: disallow-untyped-defs
+
 """
 .. currentmodule:: arraycontext
 
+.. class:: ArrayContainerT
+    :canonical: arraycontext.container.ArrayContainerT
+
+    :class:`~typing.TypeVar` for array container-like objects.
 
 .. autoclass:: ArrayContainer
 
@@ -43,9 +49,11 @@ THE SOFTWARE.
 
 from functools import singledispatch
 from arraycontext.context import ArrayContext
-from typing import Any, Iterable, Tuple, Optional
+from typing import Any, Iterable, Tuple, TypeVar, Optional
 import numpy as np
 
+ArrayContainerT = TypeVar("ArrayContainerT")
+
 
 # {{{ ArrayContainer
 
@@ -111,7 +119,7 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]:
 
 
 @singledispatch
-def deserialize_container(template, iterable: Iterable[Tuple[Any, Any]]):
+def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any:
     """Deserialize an iterable into an array container.
 
     :param template: an instance of an existing object that
@@ -131,7 +139,7 @@ def is_array_container_type(cls: type) -> bool:
     return (
             cls is ArrayContainer
             or (serialize_container.dispatch(cls)
-                is not serialize_container.__wrapped__))
+                is not serialize_container.__wrapped__))    # type: ignore
 
 
 def is_array_container(ary: Any) -> bool:
@@ -140,11 +148,11 @@ def is_array_container(ary: Any) -> bool:
         :func:`serialize_container`.
     """
     return (serialize_container.dispatch(ary.__class__)
-            is not serialize_container.__wrapped__)
+            is not serialize_container.__wrapped__)         # type: ignore
 
 
 @singledispatch
-def get_container_context(ary: ArrayContainer) -> Optional["ArrayContext"]:
+def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
     """Retrieves the :class:`ArrayContext` from the container, if any.
 
     This function is not recursive, so it will only search at the root level
@@ -169,7 +177,8 @@ def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
 
 @deserialize_container.register(np.ndarray)
 def _deserialize_ndarray_container(
-        template: Any, iterable: Iterable[Tuple[Any, Any]]):
+        template: np.ndarray,
+        iterable: Iterable[Tuple[Any, Any]]) -> np.ndarray:
     # disallow subclasses
     assert type(template) is np.ndarray
     assert template.dtype.char == "O"
@@ -185,7 +194,7 @@ def _deserialize_ndarray_container(
 
 # {{{ get_container_context_recursively
 
-def get_container_context_recursively(ary) -> Optional["ArrayContext"]:
+def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
     """Walks the :class:`ArrayContainer` hierarchy to find an
     :class:`ArrayContext` associated with it.
 
diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 02f6692..db989c1 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -1,3 +1,5 @@
+# mypy: disallow-untyped-defs
+
 """
 .. currentmodule:: arraycontext
 .. autofunction:: with_container_arithmetic
@@ -29,12 +31,16 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import Any, Callable, Optional, Tuple, TypeVar, Union
 
 import numpy as np
 
 
 # {{{ with_container_arithmetic
 
+T = TypeVar("T")
+
+
 class _OpClass(enum.Enum):
     ARITHMETIC = enum.auto
     MATMUL = enum.auto
@@ -79,7 +85,7 @@ _BINARY_OP_AND_DUNDER = [
         ]
 
 
-def _format_unary_op_str(op_str, arg1):
+def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
     if isinstance(arg1, tuple):
         arg1_entry, arg1_container = arg1
         return (f"{op_str.format(arg1_entry)} "
@@ -88,7 +94,9 @@ def _format_unary_op_str(op_str, arg1):
         return op_str.format(arg1)
 
 
-def _format_binary_op_str(op_str, arg1, arg2):
+def _format_binary_op_str(op_str: str,
+        arg1: Union[Tuple[str, ...], str],
+        arg2: Union[Tuple[str, ...], str]) -> str:
     if isinstance(arg1, tuple) and isinstance(arg2, tuple):
         import sys
         if sys.version_info >= (3, 10):
@@ -115,11 +123,18 @@ def _format_binary_op_str(op_str, arg1, arg2):
         return op_str.format(arg1, arg2)
 
 
-def with_container_arithmetic(*,
-        bcast_number=True, bcast_obj_array=None, bcast_numpy_array=False,
-        bcast_container_types=None,
-        arithmetic=True, matmul=False, bitwise=False, shift=False,
-        eq_comparison=None, rel_comparison=None):
+def with_container_arithmetic(
+        *,
+        bcast_number: bool = True,
+        bcast_obj_array: Optional[bool] = None,
+        bcast_numpy_array: bool = False,
+        bcast_container_types: Optional[Tuple[type, ...]] = None,
+        arithmetic: bool = True,
+        matmul: bool = False,
+        bitwise: bool = False,
+        shift: bool = False,
+        eq_comparison: Optional[bool] = None,
+        rel_comparison: Optional[bool] = None) -> Callable[[type], type]:
     """A class decorator that implements built-in operators for array containers
     by propagating the operations to the elements of the container.
 
@@ -188,17 +203,18 @@ def with_container_arithmetic(*,
         raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")
 
     if bcast_numpy_array:
-        def numpy_pred(name):
+        def numpy_pred(name: str) -> str:
             return f"isinstance({name}, np.ndarray)"
     elif bcast_obj_array:
-        def numpy_pred(name):
+        def numpy_pred(name: str) -> str:
             return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
     else:
-        def numpy_pred(name):
+        def numpy_pred(name: str) -> str:
             return "False"  # optimized away
 
     if bcast_container_types is None:
         bcast_container_types = ()
+    bcast_container_types_count = len(bcast_container_types)
 
     if np.ndarray in bcast_container_types and bcast_obj_array:
         raise ValueError("If numpy.ndarray is part of bcast_container_types, "
@@ -220,7 +236,7 @@ def with_container_arithmetic(*,
 
     # }}}
 
-    def wrap(cls):
+    def wrap(cls: Any) -> Any:
         if (not hasattr(cls, "_serialize_init_arrays_code")
                 or not hasattr(cls, "_deserialize_init_arrays_code")):
             raise TypeError(f"class '{cls.__name__}' must provide serialization "
@@ -242,16 +258,17 @@ def with_container_arithmetic(*,
             for i, bct in enumerate(bcast_container_types):
                 gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
             gen("")
-        outer_bcast_type_names = [
-                f"_bctype{i}" for i in range(len(bcast_container_types))]
+        outer_bcast_type_names = tuple([
+                f"_bctype{i}" for i in range(bcast_container_types_count)
+                ])
         if bcast_number:
-            outer_bcast_type_names.append("Number")
+            outer_bcast_type_names += ("Number",)
 
-        def same_key(k1, k2):
+        def same_key(k1: T, k2: T) -> T:
             assert k1 == k2
             return k1
 
-        def tup_str(t):
+        def tup_str(t: Tuple[str, ...]) -> str:
             if not t:
                 return "()"
             else:
diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index f30f3e4..37230e4 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -1,3 +1,5 @@
+# mypy: disallow-untyped-defs
+
 """
 .. currentmodule:: arraycontext
 .. autofunction:: dataclass_array_container
@@ -28,14 +30,13 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-
 from dataclasses import fields
 from arraycontext.container import is_array_container_type
 
 
 # {{{ dataclass containers
 
-def dataclass_array_container(cls):
+def dataclass_array_container(cls: type) -> type:
     """A class decorator that makes the class to which it is applied a
     :class:`ArrayContainer` by registering appropriate implementations of
     :func:`serialize_container` and :func:`deserialize_container`.
@@ -79,7 +80,7 @@ def dataclass_array_container(cls):
         from arraycontext import serialize_container, deserialize_container
 
         @serialize_container.register(cls)
-        def _serialize_{lower_cls_name}(ary: cls):
+        def _serialize_{lower_cls_name}(ary: cls) -> Iterable[Tuple[Any, Any]]:
             return ({serialize_expr},)
 
         @deserialize_container.register(cls)
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index fce37a6..3a4331a 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -1,3 +1,5 @@
+# mypy: disallow-untyped-defs
+
 """
 .. currentmodule:: arraycontext
 
@@ -46,18 +48,24 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Any, Callable
+from typing import Any, Callable, List, Optional
 from functools import update_wrapper, partial, singledispatch
 
 import numpy as np
 
-from arraycontext.container import (is_array_container,
+from arraycontext.context import ArrayContext
+from arraycontext.container import (
+        ArrayContainerT, is_array_container,
         serialize_container, deserialize_container)
 
 
 # {{{ array container traversal
 
-def _map_array_container_impl(f, ary, *, leaf_cls=None, recursive=False):
+def _map_array_container_impl(
+        f: Callable[[Any], Any],
+        ary: ArrayContainerT, *,
+        leaf_cls: Optional[type] = None,
+        recursive: bool = False) -> ArrayContainerT:
     """Helper for :func:`rec_map_array_container`.
 
     :param leaf_cls: class on which we call *f* directly. This is mostly
@@ -65,7 +73,7 @@ def _map_array_container_impl(f, ary, *, leaf_cls=None, recursive=False):
         specific container classes. By default, the recursion is stopped when
         a non-:class:`ArrayContainer` class is encountered.
     """
-    def rec(_ary):
+    def rec(_ary: ArrayContainerT) -> ArrayContainerT:
         if type(_ary) is leaf_cls:  # type(ary) is never None
             return f(_ary)
         elif is_array_container(_ary):
@@ -79,7 +87,11 @@ def _map_array_container_impl(f, ary, *, leaf_cls=None, recursive=False):
     return rec(ary)
 
 
-def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False):
+def _multimap_array_container_impl(
+        f: Callable[..., Any],
+        *args: Any,
+        leaf_cls: Optional[type] = None,
+        recursive: bool = False) -> ArrayContainerT:
     """Helper for :func:`rec_multimap_array_container`.
 
     :param leaf_cls: class on which we call *f* directly. This is mostly
@@ -87,7 +99,7 @@ def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False):
         specific container classes. By default, the recursion is stopped when
         a non-:class:`ArrayContainer` class is encountered.
     """
-    def rec(*_args):
+    def rec(*_args: Any) -> Any:
         template_ary = _args[container_indices[0]]
         assert all(
                 type(_args[i]) is type(template_ary) for i in container_indices[1:]
@@ -112,11 +124,11 @@ def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False):
 
                 new_args[i] = subary
 
-            result.append((key, frec(*new_args)))
+            result.append((key, frec(*new_args)))       # type: ignore
 
         return deserialize_container(template_ary, result)
 
-    container_indices = [
+    container_indices: List[int] = [
             i for i, arg in enumerate(args)
             if is_array_container(arg) and type(arg) is not leaf_cls]
 
@@ -126,21 +138,24 @@ def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False):
     if len(container_indices) == 1:
         # NOTE: if we just have one ArrayContainer in args, passing it through
         # _map_array_container_impl should be faster
-        def wrapper(ary):
+        def wrapper(ary: ArrayContainerT) -> ArrayContainerT:
             new_args = list(args)
             new_args[container_indices[0]] = ary
             return f(*new_args)
 
         update_wrapper(wrapper, f)
+        template_ary: ArrayContainerT = args[container_indices[0]]
         return _map_array_container_impl(
-                wrapper, args[container_indices[0]],
+                wrapper, template_ary,
                 leaf_cls=leaf_cls, recursive=recursive)
 
     frec = rec if recursive else f
     return rec(*args)
 
 
-def map_array_container(f: Callable[[Any], Any], ary):
+def map_array_container(
+        f: Callable[[Any], Any],
+        ary: ArrayContainerT) -> ArrayContainerT:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but
@@ -159,7 +174,7 @@ def map_array_container(f: Callable[[Any], Any], ary):
         return f(ary)
 
 
-def multimap_array_container(f: Callable[[Any], Any], *args):
+def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
     r"""Applies *f* to the components of multiple :class:`ArrayContainer`\ s.
 
     Works similarly to :func:`~pytools.obj_array.obj_array_vectorize_n_args`,
@@ -174,7 +189,9 @@ def multimap_array_container(f: Callable[[Any], Any], *args):
     return _multimap_array_container_impl(f, *args, recursive=False)
 
 
-def rec_map_array_container(f: Callable[[Any], Any], ary):
+def rec_map_array_container(
+        f: Callable[[Any], Any],
+        ary: ArrayContainerT) -> ArrayContainerT:
     r"""Applies *f* recursively to an :class:`ArrayContainer`.
 
     For a non-recursive version see :func:`map_array_container`.
@@ -185,14 +202,15 @@ def rec_map_array_container(f: Callable[[Any], Any], ary):
     return _map_array_container_impl(f, ary, recursive=True)
 
 
-def mapped_over_array_containers(f: Callable[[Any], Any]):
+def mapped_over_array_containers(
+        f: Callable[[Any], Any]) -> Callable[[ArrayContainerT], ArrayContainerT]:
     """Decorator around :func:`rec_map_array_container`."""
     wrapper = partial(rec_map_array_container, f)
     update_wrapper(wrapper, f)
     return wrapper
 
 
-def rec_multimap_array_container(f: Callable[[Any], Any], *args):
+def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
     r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
 
     For a non-recursive version see :func:`multimap_array_container`.
@@ -203,11 +221,12 @@ def rec_multimap_array_container(f: Callable[[Any], Any], *args):
     return _multimap_array_container_impl(f, *args, recursive=True)
 
 
-def multimapped_over_array_containers(f: Callable[[Any], Any]):
+def multimapped_over_array_containers(
+        f: Callable[..., Any]) -> Callable[..., Any]:
     """Decorator around :func:`rec_multimap_array_container`."""
     # can't use functools.partial, because its result is insufficiently
     # function-y to be used as a method definition.
-    def wrapper(*args):
+    def wrapper(*args: Any) -> Any:
         return rec_multimap_array_container(f, *args)
 
     update_wrapper(wrapper, f)
@@ -219,7 +238,9 @@ def multimapped_over_array_containers(f: Callable[[Any], Any]):
 # {{{ freeze/thaw
 
 @singledispatch
-def freeze(ary, actx=None):
+def freeze(
+        ary: ArrayContainerT,
+        actx: Optional[ArrayContext] = None) -> ArrayContainerT:
     r"""Freezes recursively by going through all components of the
     :class:`ArrayContainer` *ary*.
 
@@ -243,7 +264,7 @@ def freeze(ary, actx=None):
 
 
 @singledispatch
-def thaw(ary, actx):
+def thaw(ary: ArrayContainerT, actx: ArrayContext) -> ArrayContainerT:
     r"""Thaws recursively by going through all components of the
     :class:`ArrayContainer` *ary*.
 
@@ -276,13 +297,13 @@ def thaw(ary, actx):
 
 # {{{ numpy conversion
 
-def from_numpy(ary, actx):
+def from_numpy(ary: Any, actx: ArrayContext) -> Any:
     """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
     to the base array type of :class:`~arraycontext.ArrayContext`.
 
     The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
     """
-    def _from_numpy(subary):
+    def _from_numpy(subary: Any) -> Any:
         if isinstance(subary, np.ndarray) and subary.dtype != "O":
             return actx.from_numpy(subary)
         elif is_array_container(subary):
@@ -293,7 +314,7 @@ def from_numpy(ary, actx):
     return _from_numpy(ary)
 
 
-def to_numpy(ary, actx):
+def to_numpy(ary: Any, actx: ArrayContext) -> Any:
     """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to
     :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*.
 
diff --git a/arraycontext/py.typed b/arraycontext/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/doc/Makefile b/doc/Makefile
index d4bb2cb..d0ac5f2 100644
--- a/doc/Makefile
+++ b/doc/Makefile
@@ -4,7 +4,7 @@
 # You can set these variables from the command line, and also
 # from the environment for the first two.
 SPHINXOPTS    ?=
-SPHINXBUILD   ?= sphinx-build
+SPHINXBUILD   ?= python $(shell which sphinx-build)
 SOURCEDIR     = .
 BUILDDIR      = _build
 
diff --git a/setup.cfg b/setup.cfg
index f1124f0..183d7a9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,9 +1,31 @@
 [flake8]
+min_python_version = 3.6
 ignore = E126,E127,E128,E123,E226,E241,E242,E265,W503,E402
 max-line-length=85
 
 inline-quotes = "
 docstring-quotes = """
 multiline-quotes = """
-
 # enable-flake8-bugbear
+
+[mypy]
+python_version = 3.6
+warn_unused_ignores = True
+
+[mypy-islpy]
+ignore_missing_imports = True
+
+[mypy-loopy.*]
+ignore_missing_imports = True
+
+[mypy-numpy]
+ignore_missing_imports = True
+
+[mypy-meshmode.*]
+ignore_missing_imports = True
+
+[mypy-pymbolic.*]
+ignore_missing_imports = True
+
+[mypy-pyopencl.*]
+ignore_missing_imports = True
diff --git a/setup.py b/setup.py
index c6aa859..074d5cf 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@ def main():
             "loopy>=2019.1",
             "dataclasses; python_version<='3.6'",
         ],
+        package_data={"arraycontext": ["py.typed"]},
     )
 
 
-- 
GitLab