From 25b3af32a273b4a8b59e81f3051c03c4ee9c7e4a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 27 Nov 2024 10:55:36 -0600 Subject: [PATCH] Modernize type annotations, require __future__ annotations --- arraycontext/__init__.py | 1 + arraycontext/container/dataclass.py | 1 + arraycontext/context.py | 20 ++++++++------ arraycontext/fake_numpy.py | 3 +++ arraycontext/impl/__init__.py | 3 +++ arraycontext/impl/jax/__init__.py | 2 ++ arraycontext/impl/jax/fake_numpy.py | 3 +++ arraycontext/impl/numpy/fake_numpy.py | 3 +++ arraycontext/impl/pyopencl/fake_numpy.py | 3 +++ .../impl/pyopencl/taggable_cl_array.py | 7 ++--- arraycontext/impl/pytato/compile.py | 17 +++++++----- arraycontext/impl/pytato/fake_numpy.py | 3 +++ arraycontext/impl/pytato/utils.py | 5 +++- arraycontext/loopy.py | 2 ++ arraycontext/metadata.py | 1 + arraycontext/pytest.py | 2 ++ arraycontext/transform_metadata.py | 2 ++ arraycontext/version.py | 2 ++ doc/make_numpy_coverage_table.py | 1 + pyproject.toml | 6 +++-- test/test_arraycontext.py | 12 +++++---- test/test_pytato_arraycontext.py | 2 ++ test/test_utils.py | 27 ++++++++++++------- 23 files changed, 92 insertions(+), 36 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index c40117e..1c2ae45 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -2,6 +2,7 @@ An array context is an abstraction that helps you dispatch between multiple implementations of :mod:`numpy`-like :math:`n`-dimensional arrays. """ +from __future__ import annotations __copyright__ = """ diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index ae9ab48..9495b34 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -4,6 +4,7 @@ .. currentmodule:: arraycontext .. autofunction:: dataclass_array_container """ +from __future__ import annotations __copyright__ = """ diff --git a/arraycontext/context.py b/arraycontext/context.py index 398f8aa..f6dc70b 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -134,6 +134,9 @@ Canonical locations for type annotations :canonical: arraycontext.ArrayOrContainerOrScalarT """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -160,7 +163,7 @@ THE SOFTWARE. from abc import ABC, abstractmethod from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, Union +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union from warnings import warn import numpy as np @@ -204,14 +207,14 @@ class Array(Protocol): ... @property - def dtype(self) -> "np.dtype[Any]": + def dtype(self) -> np.dtype[Any]: ... # Covering all the possible index variations is hard and (kind of) futile. # If you'd like to see how, try changing the Any to # AxisIndex = slice | int | "Array" # Index = AxisIndex |tuple[AxisIndex] - def __getitem__(self, index: Any) -> "Array": + def __getitem__(self, index: Any) -> Array: ... @@ -220,9 +223,10 @@ Scalar = ScalarLike ArrayT = TypeVar("ArrayT", bound=Array) -ArrayOrContainer = Union[Array, "ArrayContainer"] +ArrayOrScalar: TypeAlias = "Array | ScalarLike" +ArrayOrContainer: TypeAlias = "Array | ArrayContainer" ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) -ArrayOrContainerOrScalar = Union[Array, "ArrayContainer", ScalarLike] +ArrayOrContainerOrScalar: TypeAlias = "Array | ArrayContainer | ScalarLike" ArrayOrContainerOrScalarT = TypeVar( "ArrayOrContainerOrScalarT", bound=ArrayOrContainerOrScalar) @@ -295,7 +299,7 @@ class ArrayContext(ABC): def zeros(self, shape: int | tuple[int, ...], - dtype: "np.dtype[Any]") -> Array: + dtype: np.dtype[Any]) -> Array: warn(f"{type(self).__name__}.zeros is deprecated and will stop " "working in 2025. Use actx.np.zeros instead.", DeprecationWarning, stacklevel=2) @@ -329,7 +333,7 @@ class ArrayContext(ABC): @abstractmethod def call_loopy(self, - t_unit: "loopy.TranslationUnit", + t_unit: loopy.TranslationUnit, **kwargs: Any) -> dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. @@ -414,7 +418,7 @@ class ArrayContext(ABC): @memoize_method def _get_einsum_prg(self, spec: str, arg_names: tuple[str, ...], - tagged: ToTagSetConvertible) -> "loopy.TranslationUnit": + tagged: ToTagSetConvertible) -> loopy.TranslationUnit: import loopy as lp from loopy.version import MOST_RECENT_LANGUAGE_VERSION diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 5821561..6c5fb15 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/__init__.py b/arraycontext/impl/__init__.py index ac0e47a..53030a2 100644 --- a/arraycontext/impl/__init__.py +++ b/arraycontext/impl/__init__.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 0b6cd72..a70cbaa 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -2,6 +2,8 @@ .. currentmodule:: arraycontext .. autoclass:: EagerJAXArrayContext """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 094e8cf..1a4e790 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index f345edc..582ccda 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index ae340ca..4b96e47 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -2,6 +2,9 @@ .. currentmodule:: arraycontext .. autoclass:: PyOpenCLArrayContext """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 7de7611..39f9258 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -4,6 +4,7 @@ .. autofunction:: to_tagged_cl_array """ +from __future__ import annotations from dataclasses import dataclass from typing import Any @@ -25,7 +26,7 @@ class Axis(Taggable): tags: frozenset[Tag] - def _with_new_tags(self, tags: frozenset[Tag]) -> "Axis": + def _with_new_tags(self, tags: frozenset[Tag]) -> Axis: from dataclasses import replace return replace(self, tags=tags) @@ -109,12 +110,12 @@ class TaggableCLArray(cla.Array, Taggable): return type(self)(None, tags=self.tags, axes=self.axes, **_unwrap_cl_array(ary)) - def _with_new_tags(self, tags: frozenset[Tag]) -> "TaggableCLArray": + def _with_new_tags(self, tags: frozenset[Tag]) -> TaggableCLArray: return type(self)(None, tags=tags, axes=self.axes, **_unwrap_cl_array(self)) def with_tagged_axis(self, iaxis: int, - tags: ToTagSetConvertible) -> "TaggableCLArray": + tags: ToTagSetConvertible) -> TaggableCLArray: """ Returns a copy of *self* with *iaxis*-th axis tagged with *tags*. """ diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 952761b..e77c109 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -5,6 +5,9 @@ .. autoclass:: CompiledFunction .. autoclass:: FromArrayContextCompile """ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -261,7 +264,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor], - "CompiledFunction"] = field(default_factory=lambda: {}) + CompiledFunction] = field(default_factory=lambda: {}) # {{{ abstract interface @@ -270,11 +273,11 @@ class BaseLazilyCompilingFunctionCaller: @property def compiled_function_returning_array_container_class( - self) -> type["CompiledFunction"]: + self) -> type[CompiledFunction]: raise NotImplementedError @property - def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type[CompiledFunction]: raise NotImplementedError # }}} @@ -383,11 +386,11 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> type["CompiledFunction"]: + self) -> type[CompiledFunction]: return CompiledPyOpenCLFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type[CompiledFunction]: return CompiledPyOpenCLFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): @@ -482,11 +485,11 @@ class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( - self) -> type["CompiledFunction"]: + self) -> type[CompiledFunction]: return CompiledJAXFunctionReturningArrayContainer @property - def compiled_function_returning_array_class(self) -> type["CompiledFunction"]: + def compiled_function_returning_array_class(self) -> type[CompiledFunction]: return CompiledJAXFunctionReturningArray def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 0692eb7..d707285 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 2d624d9..c031e29 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __doc__ = """ .. autofunction:: transfer_from_numpy .. autofunction:: transfer_to_numpy @@ -127,7 +130,7 @@ class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget): self.limit_arg_size_nbytes = limit_arg_size_nbytes @memoize_method - def get_loopy_target(self) -> "lp.PyOpenCLTarget": + def get_loopy_target(self) -> lp.PyOpenCLTarget: from loopy import PyOpenCLTarget return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index da71784..d6f9078 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -2,6 +2,8 @@ .. currentmodule:: arraycontext .. autofunction:: make_loopy_program """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees diff --git a/arraycontext/metadata.py b/arraycontext/metadata.py index 756999f..5f0633f 100644 --- a/arraycontext/metadata.py +++ b/arraycontext/metadata.py @@ -1,6 +1,7 @@ """ .. autoclass:: NameHint """ +from __future__ import annotations __copyright__ = """ diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index f1f62a7..760fc10 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -6,6 +6,8 @@ .. autofunction:: pytest_generate_tests_for_array_contexts """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees diff --git a/arraycontext/transform_metadata.py b/arraycontext/transform_metadata.py index 2e0942e..ccfcfba 100644 --- a/arraycontext/transform_metadata.py +++ b/arraycontext/transform_metadata.py @@ -4,6 +4,8 @@ .. autoclass:: CommonSubexpressionTag .. autoclass:: ElementwiseMapKernelTag """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees diff --git a/arraycontext/version.py b/arraycontext/version.py index d33045f..90305a2 100644 --- a/arraycontext/version.py +++ b/arraycontext/version.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from importlib import metadata diff --git a/doc/make_numpy_coverage_table.py b/doc/make_numpy_coverage_table.py index 19d09d4..1a5782e 100644 --- a/doc/make_numpy_coverage_table.py +++ b/doc/make_numpy_coverage_table.py @@ -13,6 +13,7 @@ Workflow: python make_numpy_support_table.py numpy_coverage.rst """ +from __future__ import annotations import pathlib diff --git a/pyproject.toml b/pyproject.toml index 0daaa21..a9c1df4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,8 +78,6 @@ extend-ignore = [ "E221", # multiple spaces before operator "E226", # missing whitespace around arithmetic operator "E402", # module-level import not at top of file - "UP006", # updated annotations due to __future__ import - "UP007", # updated annotations due to __future__ import ] [tool.ruff.lint.flake8-quotes] @@ -101,6 +99,10 @@ known-local-folder = [ "arraycontext", ] lines-after-imports = 2 +required-imports = ["from __future__ import annotations"] + +[tool.ruff.lint.per-file-ignores] +"doc/conf.py" = ["I002"] [tool.mypy] python_version = "3.10" diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 47d8e94..050bfc8 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + __copyright__ = "Copyright (C) 2020-21 University of Illinois Board of Trustees" __license__ = """ @@ -23,7 +26,6 @@ THE SOFTWARE. import logging from dataclasses import dataclass from functools import partial -from typing import Union import numpy as np import pytest @@ -216,9 +218,9 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: @dataclass(frozen=True) class MyContainer: name: str - mass: Union[DOFArray, np.ndarray] + mass: DOFArray | np.ndarray momentum: np.ndarray - enthalpy: Union[DOFArray, np.ndarray] + enthalpy: DOFArray | np.ndarray __array_ufunc__ = None @@ -241,9 +243,9 @@ class MyContainer: @dataclass(frozen=True) class MyContainerDOFBcast: name: str - mass: Union[DOFArray, np.ndarray] + mass: DOFArray | np.ndarray momentum: np.ndarray - enthalpy: Union[DOFArray, np.ndarray] + enthalpy: DOFArray | np.ndarray @property def array_context(self): diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index a14df50..a405038 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -1,4 +1,6 @@ """ PytatoArrayContext specific tests""" +from __future__ import annotations + __copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees" diff --git a/test/test_utils.py b/test/test_utils.py index db9ed82..a6aa271 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,7 @@ """Testing for internal utilities.""" +from __future__ import annotations + +from typing import cast __copyright__ = "Copyright (C) 2021 University of Illinois Board of Trustees" @@ -49,7 +52,6 @@ def test_pt_actx_key_stringification_uniqueness(): def test_dataclass_array_container() -> None: from dataclasses import dataclass, field - from typing import Optional, Tuple # noqa: UP035 from arraycontext import Array, dataclass_array_container @@ -58,7 +60,7 @@ def test_dataclass_array_container() -> None: @dataclass class ArrayContainerWithStringTypes: x: np.ndarray - y: "np.ndarray" + y: np.ndarray with pytest.raises(TypeError, match="String annotation on field 'y'"): # NOTE: cannot have string annotations in container @@ -71,7 +73,7 @@ def test_dataclass_array_container() -> None: @dataclass class ArrayContainerWithOptional: x: np.ndarray - y: Optional[np.ndarray] + y: np.ndarray | None with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`) @@ -84,7 +86,7 @@ def test_dataclass_array_container() -> None: @dataclass class ArrayContainerWithTuple: x: Array - y: Tuple[Array, Array] + y: tuple[Array, Array] with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"): dataclass_array_container(ArrayContainerWithTuple) @@ -131,7 +133,6 @@ def test_dataclass_array_container() -> None: def test_dataclass_container_unions() -> None: from dataclasses import dataclass - from typing import Union from arraycontext import Array, dataclass_array_container @@ -140,7 +141,7 @@ def test_dataclass_container_unions() -> None: @dataclass class ArrayContainerWithUnion: x: np.ndarray - y: Union[np.ndarray, Array] + y: np.ndarray | Array dataclass_array_container(ArrayContainerWithUnion) @@ -158,7 +159,7 @@ def test_dataclass_container_unions() -> None: @dataclass class ArrayContainerWithWrongUnion: x: np.ndarray - y: Union[np.ndarray, float] + y: np.ndarray | float with pytest.raises(TypeError, match="Field 'y' union contains non-array container"): # NOTE: float is not an ArrayContainer, so y should fail @@ -217,9 +218,15 @@ def test_stringify_array_container_tree() -> None: extent: float rng = np.random.default_rng(seed=42) - a = ArrayWrapper(ary=rng.random(10)) - d = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) - c = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + a = ArrayWrapper(ary=cast(Array, rng.random(10))) + d = SomeContainer( + points=cast(Array, rng.random((2, 10))), + radius=rng.random(), + centers=a) + c = SomeContainer( + points=cast(Array, rng.random((2, 10))), + radius=rng.random(), + centers=a) ary = SomeOtherContainer( disk=d, circle=c, has_disk=True, -- GitLab