From 3e211cc3c73cfcbe107fb93daf0fc621228cb2d8 Mon Sep 17 00:00:00 2001 From: Alex Fikl <alexfikl@gmail.com> Date: Mon, 7 Jun 2021 15:12:14 -0500 Subject: [PATCH] Enable mypy (#18) * enable mypy in setup.cfg * add some more type annotations * add py.typed * add mypy to ci * fix some type annotations --- .github/workflows/ci.yml | 17 ++++++++ .gitlab-ci.yml | 12 +++++ arraycontext/container/__init__.py | 23 +++++++--- arraycontext/container/arithmetic.py | 49 ++++++++++++++------- arraycontext/container/dataclass.py | 7 +-- arraycontext/container/traversal.py | 65 ++++++++++++++++++---------- arraycontext/py.typed | 0 doc/Makefile | 2 +- setup.cfg | 24 +++++++++- setup.py | 1 + 10 files changed, 150 insertions(+), 50 deletions(-) create mode 100644 arraycontext/py.typed diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 76e8aa8..c9bfee5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,6 +34,23 @@ jobs: curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh . ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" examples/*.py test/test_*.py + mypy: + name: Mypy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - + uses: actions/setup-python@v1 + with: + python-version: '3.x' + - name: "Main Script" + run: | + curl -L -O https://tiker.net/ci-support-v0 + . ./ci-support-v0 + build_py_project_in_conda_env + python -m pip install mypy + python -m mypy "$(basename $GITHUB_REPOSITORY)" test + pytest3: name: Pytest Conda Py3 runs-on: ubuntu-latest diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d93ad5e..e596fd7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -89,3 +89,15 @@ Pylint: - python3 except: - tags + +Mypy: + script: | + curl -L -O https://tiker.net/ci-support-v0 + . ./ci-support-v0 + build_py_project_in_venv + python -m pip install mypy + python -m mypy "$CI_PROJECT_NAME" test + tags: + - python3 + except: + - tags diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index a3398f3..397af49 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -1,6 +1,12 @@ +# mypy: disallow-untyped-defs + """ .. currentmodule:: arraycontext +.. class:: ArrayContainerT + :canonical: arraycontext.container.ArrayContainerT + + :class:`~typing.TypeVar` for array container-like objects. .. autoclass:: ArrayContainer @@ -43,9 +49,11 @@ THE SOFTWARE. from functools import singledispatch from arraycontext.context import ArrayContext -from typing import Any, Iterable, Tuple, Optional +from typing import Any, Iterable, Tuple, TypeVar, Optional import numpy as np +ArrayContainerT = TypeVar("ArrayContainerT") + # {{{ ArrayContainer @@ -111,7 +119,7 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]: @singledispatch -def deserialize_container(template, iterable: Iterable[Tuple[Any, Any]]): +def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any: """Deserialize an iterable into an array container. :param template: an instance of an existing object that @@ -131,7 +139,7 @@ def is_array_container_type(cls: type) -> bool: return ( cls is ArrayContainer or (serialize_container.dispatch(cls) - is not serialize_container.__wrapped__)) + is not serialize_container.__wrapped__)) # type: ignore def is_array_container(ary: Any) -> bool: @@ -140,11 +148,11 @@ def is_array_container(ary: Any) -> bool: :func:`serialize_container`. """ return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) + is not serialize_container.__wrapped__) # type: ignore @singledispatch -def get_container_context(ary: ArrayContainer) -> Optional["ArrayContext"]: +def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: """Retrieves the :class:`ArrayContext` from the container, if any. This function is not recursive, so it will only search at the root level @@ -169,7 +177,8 @@ def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: @deserialize_container.register(np.ndarray) def _deserialize_ndarray_container( - template: Any, iterable: Iterable[Tuple[Any, Any]]): + template: np.ndarray, + iterable: Iterable[Tuple[Any, Any]]) -> np.ndarray: # disallow subclasses assert type(template) is np.ndarray assert template.dtype.char == "O" @@ -185,7 +194,7 @@ def _deserialize_ndarray_container( # {{{ get_container_context_recursively -def get_container_context_recursively(ary) -> Optional["ArrayContext"]: +def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: """Walks the :class:`ArrayContainer` hierarchy to find an :class:`ArrayContext` associated with it. diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 02f6692..db989c1 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -1,3 +1,5 @@ +# mypy: disallow-untyped-defs + """ .. currentmodule:: arraycontext .. autofunction:: with_container_arithmetic @@ -29,12 +31,16 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import Any, Callable, Optional, Tuple, TypeVar, Union import numpy as np # {{{ with_container_arithmetic +T = TypeVar("T") + + class _OpClass(enum.Enum): ARITHMETIC = enum.auto MATMUL = enum.auto @@ -79,7 +85,7 @@ _BINARY_OP_AND_DUNDER = [ ] -def _format_unary_op_str(op_str, arg1): +def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: if isinstance(arg1, tuple): arg1_entry, arg1_container = arg1 return (f"{op_str.format(arg1_entry)} " @@ -88,7 +94,9 @@ def _format_unary_op_str(op_str, arg1): return op_str.format(arg1) -def _format_binary_op_str(op_str, arg1, arg2): +def _format_binary_op_str(op_str: str, + arg1: Union[Tuple[str, ...], str], + arg2: Union[Tuple[str, ...], str]) -> str: if isinstance(arg1, tuple) and isinstance(arg2, tuple): import sys if sys.version_info >= (3, 10): @@ -115,11 +123,18 @@ def _format_binary_op_str(op_str, arg1, arg2): return op_str.format(arg1, arg2) -def with_container_arithmetic(*, - bcast_number=True, bcast_obj_array=None, bcast_numpy_array=False, - bcast_container_types=None, - arithmetic=True, matmul=False, bitwise=False, shift=False, - eq_comparison=None, rel_comparison=None): +def with_container_arithmetic( + *, + bcast_number: bool = True, + bcast_obj_array: Optional[bool] = None, + bcast_numpy_array: bool = False, + bcast_container_types: Optional[Tuple[type, ...]] = None, + arithmetic: bool = True, + matmul: bool = False, + bitwise: bool = False, + shift: bool = False, + eq_comparison: Optional[bool] = None, + rel_comparison: Optional[bool] = None) -> Callable[[type], type]: """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. @@ -188,17 +203,18 @@ def with_container_arithmetic(*, raise TypeError("bcast_obj_array must be set if bcast_numpy_array is") if bcast_numpy_array: - def numpy_pred(name): + def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray)" elif bcast_obj_array: - def numpy_pred(name): + def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" else: - def numpy_pred(name): + def numpy_pred(name: str) -> str: return "False" # optimized away if bcast_container_types is None: bcast_container_types = () + bcast_container_types_count = len(bcast_container_types) if np.ndarray in bcast_container_types and bcast_obj_array: raise ValueError("If numpy.ndarray is part of bcast_container_types, " @@ -220,7 +236,7 @@ def with_container_arithmetic(*, # }}} - def wrap(cls): + def wrap(cls: Any) -> Any: if (not hasattr(cls, "_serialize_init_arrays_code") or not hasattr(cls, "_deserialize_init_arrays_code")): raise TypeError(f"class '{cls.__name__}' must provide serialization " @@ -242,16 +258,17 @@ def with_container_arithmetic(*, for i, bct in enumerate(bcast_container_types): gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") - outer_bcast_type_names = [ - f"_bctype{i}" for i in range(len(bcast_container_types))] + outer_bcast_type_names = tuple([ + f"_bctype{i}" for i in range(bcast_container_types_count) + ]) if bcast_number: - outer_bcast_type_names.append("Number") + outer_bcast_type_names += ("Number",) - def same_key(k1, k2): + def same_key(k1: T, k2: T) -> T: assert k1 == k2 return k1 - def tup_str(t): + def tup_str(t: Tuple[str, ...]) -> str: if not t: return "()" else: diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index f30f3e4..37230e4 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -1,3 +1,5 @@ +# mypy: disallow-untyped-defs + """ .. currentmodule:: arraycontext .. autofunction:: dataclass_array_container @@ -28,14 +30,13 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - from dataclasses import fields from arraycontext.container import is_array_container_type # {{{ dataclass containers -def dataclass_array_container(cls): +def dataclass_array_container(cls: type) -> type: """A class decorator that makes the class to which it is applied a :class:`ArrayContainer` by registering appropriate implementations of :func:`serialize_container` and :func:`deserialize_container`. @@ -79,7 +80,7 @@ def dataclass_array_container(cls): from arraycontext import serialize_container, deserialize_container @serialize_container.register(cls) - def _serialize_{lower_cls_name}(ary: cls): + def _serialize_{lower_cls_name}(ary: cls) -> Iterable[Tuple[Any, Any]]: return ({serialize_expr},) @deserialize_container.register(cls) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index fce37a6..3a4331a 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -1,3 +1,5 @@ +# mypy: disallow-untyped-defs + """ .. currentmodule:: arraycontext @@ -46,18 +48,24 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable +from typing import Any, Callable, List, Optional from functools import update_wrapper, partial, singledispatch import numpy as np -from arraycontext.container import (is_array_container, +from arraycontext.context import ArrayContext +from arraycontext.container import ( + ArrayContainerT, is_array_container, serialize_container, deserialize_container) # {{{ array container traversal -def _map_array_container_impl(f, ary, *, leaf_cls=None, recursive=False): +def _map_array_container_impl( + f: Callable[[Any], Any], + ary: ArrayContainerT, *, + leaf_cls: Optional[type] = None, + recursive: bool = False) -> ArrayContainerT: """Helper for :func:`rec_map_array_container`. :param leaf_cls: class on which we call *f* directly. This is mostly @@ -65,7 +73,7 @@ def _map_array_container_impl(f, ary, *, leaf_cls=None, recursive=False): specific container classes. By default, the recursion is stopped when a non-:class:`ArrayContainer` class is encountered. """ - def rec(_ary): + def rec(_ary: ArrayContainerT) -> ArrayContainerT: if type(_ary) is leaf_cls: # type(ary) is never None return f(_ary) elif is_array_container(_ary): @@ -79,7 +87,11 @@ def _map_array_container_impl(f, ary, *, leaf_cls=None, recursive=False): return rec(ary) -def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False): +def _multimap_array_container_impl( + f: Callable[..., Any], + *args: Any, + leaf_cls: Optional[type] = None, + recursive: bool = False) -> ArrayContainerT: """Helper for :func:`rec_multimap_array_container`. :param leaf_cls: class on which we call *f* directly. This is mostly @@ -87,7 +99,7 @@ def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False): specific container classes. By default, the recursion is stopped when a non-:class:`ArrayContainer` class is encountered. """ - def rec(*_args): + def rec(*_args: Any) -> Any: template_ary = _args[container_indices[0]] assert all( type(_args[i]) is type(template_ary) for i in container_indices[1:] @@ -112,11 +124,11 @@ def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False): new_args[i] = subary - result.append((key, frec(*new_args))) + result.append((key, frec(*new_args))) # type: ignore return deserialize_container(template_ary, result) - container_indices = [ + container_indices: List[int] = [ i for i, arg in enumerate(args) if is_array_container(arg) and type(arg) is not leaf_cls] @@ -126,21 +138,24 @@ def _multimap_array_container_impl(f, *args, leaf_cls=None, recursive=False): if len(container_indices) == 1: # NOTE: if we just have one ArrayContainer in args, passing it through # _map_array_container_impl should be faster - def wrapper(ary): + def wrapper(ary: ArrayContainerT) -> ArrayContainerT: new_args = list(args) new_args[container_indices[0]] = ary return f(*new_args) update_wrapper(wrapper, f) + template_ary: ArrayContainerT = args[container_indices[0]] return _map_array_container_impl( - wrapper, args[container_indices[0]], + wrapper, template_ary, leaf_cls=leaf_cls, recursive=recursive) frec = rec if recursive else f return rec(*args) -def map_array_container(f: Callable[[Any], Any], ary): +def map_array_container( + f: Callable[[Any], Any], + ary: ArrayContainerT) -> ArrayContainerT: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but @@ -159,7 +174,7 @@ def map_array_container(f: Callable[[Any], Any], ary): return f(ary) -def multimap_array_container(f: Callable[[Any], Any], *args): +def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: r"""Applies *f* to the components of multiple :class:`ArrayContainer`\ s. Works similarly to :func:`~pytools.obj_array.obj_array_vectorize_n_args`, @@ -174,7 +189,9 @@ def multimap_array_container(f: Callable[[Any], Any], *args): return _multimap_array_container_impl(f, *args, recursive=False) -def rec_map_array_container(f: Callable[[Any], Any], ary): +def rec_map_array_container( + f: Callable[[Any], Any], + ary: ArrayContainerT) -> ArrayContainerT: r"""Applies *f* recursively to an :class:`ArrayContainer`. For a non-recursive version see :func:`map_array_container`. @@ -185,14 +202,15 @@ def rec_map_array_container(f: Callable[[Any], Any], ary): return _map_array_container_impl(f, ary, recursive=True) -def mapped_over_array_containers(f: Callable[[Any], Any]): +def mapped_over_array_containers( + f: Callable[[Any], Any]) -> Callable[[ArrayContainerT], ArrayContainerT]: """Decorator around :func:`rec_map_array_container`.""" wrapper = partial(rec_map_array_container, f) update_wrapper(wrapper, f) return wrapper -def rec_multimap_array_container(f: Callable[[Any], Any], *args): +def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s. For a non-recursive version see :func:`multimap_array_container`. @@ -203,11 +221,12 @@ def rec_multimap_array_container(f: Callable[[Any], Any], *args): return _multimap_array_container_impl(f, *args, recursive=True) -def multimapped_over_array_containers(f: Callable[[Any], Any]): +def multimapped_over_array_containers( + f: Callable[..., Any]) -> Callable[..., Any]: """Decorator around :func:`rec_multimap_array_container`.""" # can't use functools.partial, because its result is insufficiently # function-y to be used as a method definition. - def wrapper(*args): + def wrapper(*args: Any) -> Any: return rec_multimap_array_container(f, *args) update_wrapper(wrapper, f) @@ -219,7 +238,9 @@ def multimapped_over_array_containers(f: Callable[[Any], Any]): # {{{ freeze/thaw @singledispatch -def freeze(ary, actx=None): +def freeze( + ary: ArrayContainerT, + actx: Optional[ArrayContext] = None) -> ArrayContainerT: r"""Freezes recursively by going through all components of the :class:`ArrayContainer` *ary*. @@ -243,7 +264,7 @@ def freeze(ary, actx=None): @singledispatch -def thaw(ary, actx): +def thaw(ary: ArrayContainerT, actx: ArrayContext) -> ArrayContainerT: r"""Thaws recursively by going through all components of the :class:`ArrayContainer` *ary*. @@ -276,13 +297,13 @@ def thaw(ary, actx): # {{{ numpy conversion -def from_numpy(ary, actx): +def from_numpy(ary: Any, actx: ArrayContext) -> Any: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ - def _from_numpy(subary): + def _from_numpy(subary: Any) -> Any: if isinstance(subary, np.ndarray) and subary.dtype != "O": return actx.from_numpy(subary) elif is_array_container(subary): @@ -293,7 +314,7 @@ def from_numpy(ary, actx): return _from_numpy(ary) -def to_numpy(ary, actx): +def to_numpy(ary: Any, actx: ArrayContext) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. diff --git a/arraycontext/py.typed b/arraycontext/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/doc/Makefile b/doc/Makefile index d4bb2cb..d0ac5f2 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -4,7 +4,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build +SPHINXBUILD ?= python $(shell which sphinx-build) SOURCEDIR = . BUILDDIR = _build diff --git a/setup.cfg b/setup.cfg index f1124f0..183d7a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,31 @@ [flake8] +min_python_version = 3.6 ignore = E126,E127,E128,E123,E226,E241,E242,E265,W503,E402 max-line-length=85 inline-quotes = " docstring-quotes = """ multiline-quotes = """ - # enable-flake8-bugbear + +[mypy] +python_version = 3.6 +warn_unused_ignores = True + +[mypy-islpy] +ignore_missing_imports = True + +[mypy-loopy.*] +ignore_missing_imports = True + +[mypy-numpy] +ignore_missing_imports = True + +[mypy-meshmode.*] +ignore_missing_imports = True + +[mypy-pymbolic.*] +ignore_missing_imports = True + +[mypy-pyopencl.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index c6aa859..074d5cf 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ def main(): "loopy>=2019.1", "dataclasses; python_version<='3.6'", ], + package_data={"arraycontext": ["py.typed"]}, ) -- GitLab