diff --git a/pyproject.toml b/pyproject.toml index 7a38ca32b7128de206bc89646b21f044d0f878a7..505eda7de78dc623316b32c3c20a9038f819191d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ license = { text = "MIT" } authors = [ { name = "Andreas Kloeckner", email = "inform@tiker.net" }, ] -requires-python = ">=3.8" +requires-python = ">=3.10" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -68,7 +68,6 @@ pytools = [ ] [tool.ruff] -target-version = "py38" preview = true [tool.ruff.lint] @@ -110,7 +109,7 @@ known-local-folder = [ lines-after-imports = 2 [tool.mypy] -python_version = "3.8" +python_version = "3.10" ignore_missing_imports = true warn_unused_ignores = true # TODO: enable this at some point diff --git a/pytools/__init__.py b/pytools/__init__.py index 9f8113ebb163411a236b8070306aba14317e0325..e96c23d6c123bcc07703530b715564b34d9f2414 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -25,34 +25,21 @@ THE SOFTWARE. import builtins import logging -import math import operator import re import sys +from collections.abc import Callable, Collection, Hashable, Iterable, Mapping, Sequence from functools import reduce, wraps from sys import intern from typing import ( Any, - Callable, ClassVar, - Collection, - Dict, + Concatenate, Generic, - Hashable, - Iterable, - List, - Mapping, - Optional, - Sequence, - Set, - Tuple, - Type, + ParamSpec, TypeVar, - Union, ) -from typing_extensions import Concatenate, ParamSpec, SupportsIndex - # These are deprecated and will go away in 2022. all = builtins.all @@ -67,8 +54,6 @@ Math ---- .. autofunction:: levi_civita -.. autofunction:: perm -.. autofunction:: comb Assertive accessors ------------------- @@ -235,7 +220,7 @@ P = ParamSpec("P") # {{{ code maintenance class MovedFunctionDeprecationWrapper: - def __init__(self, f: F, deadline: Optional[Union[int, str]] = None) -> None: + def __init__(self, f: F, deadline: int | str | None = None) -> None: if deadline is None: deadline = "the future" @@ -252,8 +237,8 @@ class MovedFunctionDeprecationWrapper: def deprecate_keyword(oldkey: str, - newkey: Optional[str] = None, *, - deadline: Optional[str] = None): + newkey: str | None = None, *, + deadline: str | None = None): """Decorator used to deprecate function keyword arguments. :arg oldkey: deprecated argument name. @@ -304,7 +289,7 @@ def delta(x, y): return 0 -def levi_civita(tup: Tuple[int, ...]) -> int: +def levi_civita(tup: tuple[int, ...]) -> int: """Compute an entry of the Levi-Civita symbol for the indices *tuple*.""" if len(tup) == 2: i, j = tup @@ -316,75 +301,6 @@ def levi_civita(tup: Tuple[int, ...]) -> int: raise NotImplementedError(f"Levi-Civita symbol in {len(tup)} dimensions") -factorial = MovedFunctionDeprecationWrapper(math.factorial, deadline=2023) - -try: - # NOTE: only available in python >= 3.8 - perm = MovedFunctionDeprecationWrapper(math.perm, deadline=2023) -except AttributeError: - def _unchecked_perm(n, k): - result = 1 - while k: - result *= n - n -= 1 - k -= 1 - - return result - - def perm(n: SupportsIndex, # type: ignore[misc] - k: Optional[SupportsIndex] = None) -> int: - """ - :returns: :math:`P(n, k)`, the number of permutations of length :math:`k` - drawn from :math:`n` choices. - """ - from warnings import warn - warn("This function is deprecated and will go away in 2023. " - "Use `math.perm` instead, which is available from Python 3.8.", - DeprecationWarning, stacklevel=2) - - if k is None: - return math.factorial(n) - - import operator - n, k = operator.index(n), operator.index(k) - if k > n: - return 0 - - if k < 0: - raise ValueError("k must be a non-negative integer") - - if n < 0: - raise ValueError("n must be a non-negative integer") - - from numbers import Integral - if not isinstance(k, Integral): - raise TypeError(f"'{type(k).__name__}' object cannot be interpreted " - "as an integer") - - if not isinstance(n, Integral): - raise TypeError(f"'{type(n).__name__}' object cannot be interpreted " - "as an integer") - - return _unchecked_perm(n, k) - -try: - # NOTE: only available in python >= 3.8 - comb = MovedFunctionDeprecationWrapper(math.comb, deadline=2023) -except AttributeError: - def comb(n: SupportsIndex, # type: ignore[misc] - k: SupportsIndex) -> int: - """ - :returns: :math:`C(n, k)`, the number of combinations (subsets) - of length :math:`k` drawn from :math:`n` choices. - """ - from warnings import warn - warn("This function is deprecated and will go away in 2023. " - "Use `math.comb` instead, which is available from Python 3.8.", - DeprecationWarning, stacklevel=2) - - return _unchecked_perm(n, k) // math.factorial(k) - - def norm_1(iterable): return sum(abs(x) for x in iterable) @@ -420,12 +336,12 @@ class RecordWithoutPickling: will be individually derived from this class. """ - __slots__: ClassVar[List[str]] = [] - fields: ClassVar[Set[str]] + __slots__: ClassVar[list[str]] = [] + fields: ClassVar[set[str]] def __init__(self, - valuedict: Optional[Mapping[str, Any]] = None, - exclude: Optional[Sequence[str]] = None, + valuedict: Mapping[str, Any] | None = None, + exclude: Sequence[str] | None = None, **kwargs: Any) -> None: assert self.__class__ is not Record @@ -481,7 +397,7 @@ class RecordWithoutPickling: class Record(RecordWithoutPickling): - __slots__: ClassVar[List[str]] = [] + __slots__: ClassVar[list[str]] = [] def __getstate__(self): return { @@ -697,7 +613,7 @@ def memoize(*args: F, **kwargs: Any) -> F: use_kw = bool(kwargs.pop("use_kwargs", False)) - default_key_func: Optional[Callable[..., Any]] + default_key_func: Callable[..., Any] | None if use_kw: def default_key_func(*inner_args, **inner_kwargs): @@ -768,7 +684,7 @@ class _HasKwargs: def memoize_on_first_arg( function: Callable[Concatenate[T, P], R], *, - cache_dict_name: Optional[str] = None) -> Callable[Concatenate[T, P], R]: + cache_dict_name: str | None = None) -> Callable[Concatenate[T, P], R]: """Like :func:`memoize_method`, but for functions that take the object in which do memoization information is stored as first argument. @@ -846,7 +762,7 @@ class keyed_memoize_on_first_arg(Generic[T, P, R]): # noqa: N801 def __init__(self, key: Callable[P, Hashable], *, - cache_dict_name: Optional[str] = None) -> None: + cache_dict_name: str | None = None) -> None: self.key = key self.cache_dict_name = cache_dict_name @@ -1043,7 +959,7 @@ def monkeypatch_class(_name, bases, namespace): # {{{ generic utilities def add_tuples(t1, t2): - return tuple(t1v + t2v for t1v, t2v in zip(t1, t2)) + return tuple(t1v + t2v for t1v, t2v in zip(t1, t2, strict=True)) def negate_tuple(t1): @@ -1091,7 +1007,7 @@ def general_sum(sequence): def linear_combination(coefficients, vectors): result = coefficients[0] * vectors[0] - for c, v in zip(coefficients[1:], vectors[1:]): + for c, v in zip(coefficients[1:], vectors[1:], strict=True): result += c*v return result @@ -1566,7 +1482,7 @@ class Table: .. automethod:: text_without_markup """ - def __init__(self, alignments: Optional[Tuple[str, ...]] = None) -> None: + def __init__(self, alignments: tuple[str, ...] | None = None) -> None: """Create a new :class:`Table`. :arg alignments: A :class:`tuple` of alignments of each column: @@ -1585,7 +1501,7 @@ class Table: alignments = tuple(alignments) - self.rows: List[Tuple[str, ...]] = [] + self.rows: list[tuple[str, ...]] = [] self.alignments = alignments @property @@ -1598,7 +1514,7 @@ class Table: """The number of columns currently in the table.""" return len(self.rows[0]) - def add_row(self, row: Tuple[Any, ...]) -> None: + def add_row(self, row: tuple[Any, ...]) -> None: """Add *row* to the table. Note that all rows must have the same number of columns.""" if self.rows and len(row) != self.ncolumns: @@ -1608,14 +1524,15 @@ class Table: self.rows.append(tuple(str(i) for i in row)) - def _get_alignments(self) -> Tuple[str, ...]: + def _get_alignments(self) -> tuple[str, ...]: # NOTE: If not all alignments were specified, extend alignments with the # last alignment specified - return (self.alignments + return ( + self.alignments + (self.alignments[-1],) * (self.ncolumns - len(self.alignments)) - ) + )[:self.ncolumns] - def _get_column_widths(self, rows) -> Tuple[int, ...]: + def _get_column_widths(self, rows) -> tuple[int, ...]: return tuple( max(len(row[i]) for row in rows) for i in range(self.ncolumns) ) @@ -1642,13 +1559,13 @@ class Table: col_widths = self._get_column_widths(self.rows) lines = [" | ".join([ - cell.center(col_width) if align == "c" - else cell.ljust(col_width) if align == "l" - else cell.rjust(col_width) - for cell, col_width, align in zip(row, col_widths, alignments)]) + cell.center(cwidth) if align == "c" + else cell.ljust(cwidth) if align == "l" + else cell.rjust(cwidth) + for cell, cwidth, align in zip(row, col_widths, alignments, strict=True)]) for row in self.rows] - lines[1:1] = ["+".join("-" * (col_width + 1 + (i > 0)) - for i, col_width in enumerate(col_widths))] + lines[1:1] = ["+".join("-" * (cwidth + 1 + (i > 0)) + for i, cwidth in enumerate(col_widths))] return "\n".join(lines) @@ -1680,22 +1597,23 @@ class Table: col_widths = self._get_column_widths(rows) lines = [" | ".join([ - cell.center(col_width) if align == "c" - else cell.ljust(col_width) if align == "l" - else cell.rjust(col_width) - for cell, col_width, align in zip(row, col_widths, alignments)]) + cell.center(cwidth) if align == "c" + else cell.ljust(cwidth) if align == "l" + else cell.rjust(cwidth) + for cell, cwidth, align in zip(row, col_widths, alignments, strict=True)]) for row in rows] lines[1:1] = ["|".join( - (":" + "-" * (col_width - 1 + (i > 0)) + ":") if align == "c" - else (":" + "-" * (col_width + (i > 0))) if align == "l" - else ("-" * (col_width + (i > 0)) + ":") - for i, (col_width, align) in enumerate(zip(col_widths, alignments)))] + (":" + "-" * (cwidth - 1 + (i > 0)) + ":") if align == "c" + else (":" + "-" * (cwidth + (i > 0))) if align == "l" + else ("-" * (cwidth + (i > 0)) + ":") + for i, (cwidth, align) in enumerate( + zip(col_widths, alignments, strict=True)))] return "\n".join(lines) def csv(self, dialect: str = "excel", - csv_kwargs: Optional[Dict[str, Any]] = None) -> str: + csv_kwargs: dict[str, Any] | None = None) -> str: """Returns a string containing a CSV representation of the table. :arg dialect: String passed to :func:`csv.writer`. @@ -1732,7 +1650,7 @@ class Table: def latex(self, skip_lines: int = 0, - hline_after: Optional[Tuple[int, ...]] = None) -> str: + hline_after: tuple[int, ...] | None = None) -> str: r"""Returns a string containing the rows of a LaTeX representation of the table. @@ -1785,10 +1703,10 @@ class Table: col_widths = self._get_column_widths(self.rows) lines = [" ".join([ - cell.center(col_width) if align == "c" - else cell.ljust(col_width) if align == "l" - else cell.rjust(col_width) - for cell, col_width, align in zip(row, col_widths, alignments)]) + cell.center(cwidth) if align == "c" + else cell.ljust(cwidth) if align == "l" + else cell.rjust(cwidth) + for cell, cwidth, align in zip(row, col_widths, alignments, strict=True)]) for row in self.rows] # Remove the extra space added by the last cell @@ -1798,7 +1716,7 @@ class Table: def merge_tables(*tables: Table, - skip_columns: Optional[Tuple[int, ...]] = None) -> Table: + skip_columns: tuple[int, ...] | None = None) -> Table: """ :arg skip_columns: a :class:`tuple` of column indices to skip in all the tables except the first one. @@ -1889,7 +1807,7 @@ def string_histogram( bin_value, bin_value/total_count*100, format_bar(bin_value)) - for bin_start, bin_value in zip(bin_starts, bins)) + for bin_start, bin_value in zip(bin_starts, bins, strict=True)) # }}} @@ -2011,7 +1929,7 @@ class StderrToStdout: def typedump(val: Any, max_seq: int = 5, - special_handlers: Optional[Mapping[Type, Callable]] = None, + special_handlers: Mapping[type, Callable] | None = None, fully_qualified_name: bool = True) -> str: """ Return a string representation of the type of *val*, recursing into @@ -2262,7 +2180,7 @@ UNIQUE_NAME_GEN_COUNTER_RE = re.compile(r"^(?P<based_on>\w+)_(?P<counter>\d+)$") def generate_numbered_unique_names( - prefix: str, num: Optional[int] = None) -> Iterable[Tuple[int, str]]: + prefix: str, num: int | None = None) -> Iterable[tuple[int, str]]: if num is None: yield (0, prefix) num = 0 @@ -2289,7 +2207,7 @@ class UniqueNameGenerator: .. automethod:: __call__ """ def __init__(self, - existing_names: Optional[Collection[str]] = None, + existing_names: Collection[str] | None = None, forced_prefix: str = ""): """ Create a new :class:`UniqueNameGenerator`. @@ -2303,7 +2221,7 @@ class UniqueNameGenerator: self.existing_names = set(existing_names) self.forced_prefix = forced_prefix - self.prefix_to_counter: Dict[str, int] = {} + self.prefix_to_counter: dict[str, int] = {} def is_name_conflicting(self, name: str) -> bool: """Returns *True* if *name* conflicts with an existing :class:`str`.""" @@ -2773,44 +2691,14 @@ def resolve_name(name): .. versionadded:: 2021.1.2 """ - # Delete the tail of the function and deprecate this once we require Python 3.9. - if sys.version_info >= (3, 9): - # use the official version - import pkgutil - return pkgutil.resolve_name(name) - - import importlib - - m = _NAME_PATTERN.match(name) - if not m: - raise ValueError(f"invalid format: {name!r}") - groups = m.groups() - if groups[2]: - # there is a colon - a one-step import is all that's needed - mod = importlib.import_module(groups[0]) - parts = groups[3].split(".") if groups[3] else [] - else: - # no colon - have to iterate to find the package boundary - parts = name.split(".") - modname = parts.pop(0) - # first part *must* be a module/package. - mod = importlib.import_module(modname) - while parts: - p = parts[0] - s = f"{modname}.{p}" - try: - mod = importlib.import_module(s) - parts.pop(0) - modname = s - except ImportError: - break - # if we reach this point, mod is the module, already imported, and - # parts is the list of parts in the object hierarchy to be traversed, or - # an empty list if just the module is wanted. - result = mod - for p in parts: - result = getattr(result, p) - return result + from warnings import warn + + warn("'pytools.resolve_name' is deprecated and will be removed in 2024. " + "Use 'pkgutil.resolve_name' from the standard library instead.", + DeprecationWarning, stacklevel=2) + + import pkgutil + return pkgutil.resolve_name(name) # }}} @@ -2819,7 +2707,7 @@ def resolve_name(name): def unordered_hash(hash_instance: Any, iterable: Iterable[Any], - hash_constructor: Optional[Callable[[], Any]] = None) -> Any: + hash_constructor: Callable[[], Any] | None = None) -> Any: """Using a hash algorithm given by the parameter-less constructor *hash_constructor*, return a hash object whose internal state depends on the entries of *iterable*, but not their order. If *hash* @@ -2875,7 +2763,7 @@ def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0): """ import numpy as np - points: List[np.ndarray] = [] + points: list[np.ndarray] = [] count = 0 a = 4 * np.pi / npoints_approx @@ -2915,7 +2803,7 @@ _SPHERE_FIBONACCI_OFFSET = ( def sphere_sample_fibonacci( npoints: int, r: float = 1.0, *, - optimize: Optional[str] = None): + optimize: str | None = None): """Generate points on a sphere based on an offset Fibonacci lattice from [2]_. .. [2] http://extremelearning.com.au/how-to-evenly-distribute-points-on-a-sphere-more-effectively-than-the-canonical-fibonacci-lattice/ @@ -2955,7 +2843,7 @@ def sphere_sample_fibonacci( # {{{ strtobool -def strtobool(val: Optional[str], default: Optional[bool] = None) -> bool: +def strtobool(val: str | None, default: bool | None = None) -> bool: """Convert a string representation of truth to True or False. True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. Uppercase versions are @@ -3062,7 +2950,7 @@ def unique_union(*args: Iterable[T]) -> Collection[T]: if not args: return [] - res: Dict[T, None] = {} + res: dict[T, None] = {} for seq in args: for item in seq: if item not in res: diff --git a/pytools/codegen.py b/pytools/codegen.py index 8f30b75b2b5827e09f163fc0d990a90ea34924da..bd274639cda5d78621497ecac004095a74eddf64 100644 --- a/pytools/codegen.py +++ b/pytools/codegen.py @@ -29,7 +29,7 @@ Tools for Source Code Generation .. autofunction:: remove_common_indentation """ -from typing import Any, List +from typing import Any # {{{ code generation @@ -48,8 +48,8 @@ class CodeGenerator: .. automethod:: dedent """ def __init__(self) -> None: - self.preamble: List[str] = [] - self.code: List[str] = [] + self.preamble: list[str] = [] + self.code: list[str] = [] self.level = 0 self.indent_amount = 4 diff --git a/pytools/convergence.py b/pytools/convergence.py index 7cf96925e6d5431f8685e528aef755041354ae62..ee88f056fa924ffa2536b56e6546ff7e350d6fc0 100644 --- a/pytools/convergence.py +++ b/pytools/convergence.py @@ -7,7 +7,6 @@ import numbers -from typing import List, Optional, Tuple import numpy as np @@ -45,7 +44,7 @@ class EOCRecorder: """ def __init__(self) -> None: - self.history: List[Tuple[float, float]] = [] + self.history: list[tuple[float, float]] = [] def add_data_point(self, abscissa: float, error: float) -> None: if not (isinstance(abscissa, numbers.Number) @@ -60,7 +59,7 @@ class EOCRecorder: self.history.append((abscissa, error)) def estimate_order_of_convergence(self, - gliding_mean: Optional[int] = None, + gliding_mean: int | None = None, ) -> np.ndarray: abscissae = np.array([a for a, e in self.history]) errors = np.array([e for a, e in self.history]) @@ -159,7 +158,7 @@ class EOCRecorder: def stringify_eocs(*eocs: EOCRecorder, - names: Optional[Tuple[str, ...]] = None, + names: tuple[str, ...] | None = None, abscissa_label: str = "h", error_label: str = "Error", gliding_mean: int = 2, @@ -186,7 +185,7 @@ def stringify_eocs(*eocs: EOCRecorder, error_format=error_format, eoc_format=eoc_format, gliding_mean=gliding_mean) - for name, eoc in zip(names, eocs) + for name, eoc in zip(names, eocs, strict=True) ], skip_columns=(0,)) if table_type == "markdown": @@ -219,7 +218,7 @@ class PConvergenceVerifier: tbl = Table() tbl.add_row(("p", "error")) - for p, err in zip(self.orders, self.errors): + for p, err in zip(self.orders, self.errors, strict=True): tbl.add_row((str(p), str(err))) return str(tbl) diff --git a/pytools/datatable.py b/pytools/datatable.py index ef8b63c0bd01521096419a35e308a6e9b712c74a..e4dd79095d8470e01f4a49e8d9b0fb872522238a 100644 --- a/pytools/datatable.py +++ b/pytools/datatable.py @@ -1,4 +1,5 @@ -from typing import IO, Any, Callable, Iterator, List, Optional, Sequence, Tuple +from collections.abc import Callable, Iterator, Sequence +from typing import IO, Any from pytools import Record @@ -25,7 +26,7 @@ class DataTable: """ def __init__(self, column_names: Sequence[str], - column_data: Optional[List[Any]] = None) -> None: + column_data: list[Any] | None = None) -> None: """Construct a new table, with the given C{column_names}. :arg column_names: An indexable of column name strings. @@ -50,7 +51,7 @@ class DataTable: def __len__(self) -> int: return len(self.data) - def __iter__(self) -> Iterator[List[Any]]: + def __iter__(self) -> Iterator[list[Any]]: return self.data.__iter__() def __str__(self) -> str: @@ -65,7 +66,7 @@ class DataTable: def format_row(row: Sequence[str]) -> str: return "|".join([str(cell).ljust(col_width) - for cell, col_width in zip(row, col_widths)]) + for cell, col_width in zip(row, col_widths, strict=True)]) lines = [format_row(self.column_names), "+".join("-"*col_width for col_width in col_widths)] + \ @@ -80,12 +81,12 @@ class DataTable: self.insert_row(tuple(values)) - def insert_row(self, values: Tuple[Any, ...]) -> None: + def insert_row(self, values: tuple[Any, ...]) -> None: assert isinstance(values, tuple) assert len(values) == len(self.column_names) self.data.append(values) - def insert_rows(self, rows: Sequence[Tuple[Any, ...]]) -> None: + def insert_rows(self, rows: Sequence[tuple[Any, ...]]) -> None: for row in rows: self.insert_row(row) @@ -118,7 +119,7 @@ class DataTable: if len(filtered) > 1: raise RuntimeError("more than one matching entry for get()") - return Row(dict(list(zip(self.column_names, filtered.data[0])))) + return Row(dict(zip(self.column_names, filtered.data[0], strict=True))) def clear(self) -> None: del self.data[:] @@ -140,7 +141,7 @@ class DataTable: def sort(self, columns: Sequence[str], reverse: bool = False) -> None: col_indices = [self.column_indices[col] for col in columns] - def mykey(row: Sequence[Any]) -> Tuple[Any, ...]: + def mykey(row: Sequence[Any]) -> tuple[Any, ...]: return tuple( row[col_index] for col_index in col_indices) @@ -157,8 +158,8 @@ class DataTable: result_data = [] # to pacify pyflakes: - last_values: Tuple[Any, ...] = () - agg_values: List[Row] = [] + last_values: tuple[Any, ...] = () + agg_values: list[Row] = [] for row in self.data: this_values = tuple(row[i] for i in gb_indices) @@ -192,7 +193,7 @@ class DataTable: by which they are joined. """ - def without(indexable: Tuple[str, ...], idx: int) -> Tuple[str, ...]: + def without(indexable: tuple[str, ...], idx: int) -> tuple[str, ...]: return indexable[:idx] + indexable[idx+1:] this_key_idx = self.column_indices[column] @@ -278,7 +279,7 @@ class DataTable: return DataTable(columns, [[row[i] for i in col_indices] for row in self.data]) - def column_data(self, column: str) -> List[Tuple[Any, ...]]: + def column_data(self, column: str) -> list[tuple[Any, ...]]: col_index = self.column_indices[column] return [row[col_index] for row in self.data] diff --git a/pytools/graph.py b/pytools/graph.py index a091d95d38b01d0f14fe0c4324b5fa72917e66d8..9b9816305397ef9b0f3595bf493a1b1c943f2f85 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -70,29 +70,24 @@ Type Variables Used is included as a key in the graph. """ -from dataclasses import dataclass -from typing import ( - Any, +from collections.abc import ( Callable, Collection, - Dict, - FrozenSet, - Generic, Hashable, Iterable, Iterator, - List, Mapping, MutableSet, - Optional, +) +from dataclasses import dataclass +from typing import ( + Any, + Generic, Protocol, - Set, - Tuple, + TypeAlias, TypeVar, ) -from typing_extensions import TypeAlias - NodeT = TypeVar("NodeT", bound=Hashable) @@ -107,7 +102,7 @@ def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]: :returns: A :class:`dict` representing *graph* with edges reversed. """ - result: Dict[NodeT, Set[NodeT]] = {} + result: dict[NodeT, set[NodeT]] = {} for node_key, successor_nodes in graph.items(): # Make sure every node is in the result even if it has no successors @@ -125,9 +120,9 @@ def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]: def a_star( initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT[NodeT], - estimate_remaining_cost: Optional[Callable[[NodeT], float]] = None, + estimate_remaining_cost: Callable[[NodeT], float] | None = None, get_step_cost: Callable[[Any, NodeT], float] = lambda x, y: 1 - ) -> List[NodeT]: + ) -> list[NodeT]: """ With the default cost and heuristic, this amounts to Dijkstra's algorithm. """ @@ -162,7 +157,7 @@ def a_star( if top.state == goal_state: result = [] - it: Optional[AStarNode] = top + it: AStarNode | None = top while it is not None: result.append(it.state) it = it.parent @@ -189,21 +184,20 @@ def a_star( # {{{ compute SCCs with Tarjan's algorithm -def compute_sccs(graph: GraphT[NodeT]) -> List[List[NodeT]]: +def compute_sccs(graph: GraphT[NodeT]) -> list[list[NodeT]]: to_search = set(graph.keys()) - visit_order: Dict[NodeT, int] = {} + visit_order: dict[NodeT, int] = {} scc_root = {} sccs = [] while to_search: top = next(iter(to_search)) - call_stack: List[Tuple[NodeT, Iterator[NodeT], Optional[NodeT]]] = [(top, - iter(graph[top]), - None)] + call_stack: list[tuple[NodeT, Iterator[NodeT], NodeT | None]] = ( + [(top, iter(graph[top]), None)]) visit_stack = [] visiting = set() - scc: List[NodeT] = [] + scc: list[NodeT] = [] while call_stack: top, children, last_popped_child = call_stack.pop() @@ -283,8 +277,8 @@ class _HeapEntry(Generic[NodeT]): def compute_topological_order( graph: GraphT[NodeT], - key: Optional[Callable[[NodeT], _SupportsLT]] = None, - ) -> List[NodeT]: + key: Callable[[NodeT], _SupportsLT] | None = None, + ) -> list[NodeT]: """Compute a topological order of nodes in a directed graph. :arg key: A custom key function may be supplied to determine the order in @@ -408,8 +402,8 @@ def contains_cycle(graph: GraphT[NodeT]) -> bool: # {{{ compute induced subgraph -def compute_induced_subgraph(graph: Mapping[NodeT, Set[NodeT]], - subgraph_nodes: Set[NodeT]) -> GraphT[NodeT]: +def compute_induced_subgraph(graph: Mapping[NodeT, set[NodeT]], + subgraph_nodes: set[NodeT]) -> GraphT[NodeT]: """Compute the induced subgraph formed by a subset of the vertices in a graph. @@ -439,9 +433,9 @@ def compute_induced_subgraph(graph: Mapping[NodeT, Set[NodeT]], # {{{ as_graphviz_dot def as_graphviz_dot(graph: GraphT[NodeT], - node_labels: Optional[Callable[[NodeT], str]] = None, - edge_labels: Optional[Callable[[NodeT, NodeT], str]] = None) \ - -> str: + node_labels: Callable[[NodeT], str] | None = None, + edge_labels: Callable[[NodeT, NodeT], str] | None = None, + ) -> str: """ Create a visualization of the graph *graph* in the `dot <http://graphviz.org/>`__ language. @@ -502,7 +496,7 @@ def validate_graph(graph: GraphT[NodeT]) -> None: Validates that all successor nodes of each node in *graph* are keys in *graph* itself. Raises a :class:`ValueError` if not. """ - seen_nodes: Set[NodeT] = set() + seen_nodes: set[NodeT] = set() for children in graph.values(): seen_nodes.update(children) @@ -549,7 +543,7 @@ def is_connected(graph: GraphT[NodeT]) -> bool: def undirected_graph_from_edges( - edges: Iterable[Tuple[NodeT, NodeT]], + edges: Iterable[tuple[NodeT, NodeT]], ) -> GraphT[NodeT]: """ Constructs an undirected graph using *edges*. @@ -558,7 +552,7 @@ def undirected_graph_from_edges( :returns: A :class:`GraphT` that is the undirected graph. """ - undirected_graph: Dict[NodeT, Set[NodeT]] = {} + undirected_graph: dict[NodeT, set[NodeT]] = {} for lhs, rhs in edges: if lhs == rhs: @@ -573,12 +567,12 @@ def undirected_graph_from_edges( def get_reachable_nodes( undirected_graph: GraphT[NodeT], - source_node: NodeT) -> FrozenSet[NodeT]: + source_node: NodeT) -> frozenset[NodeT]: """ Returns a :class:`frozenset` of all nodes in *undirected_graph* that are reachable from *source_node*. """ - nodes_visited: Set[NodeT] = set() + nodes_visited: set[NodeT] = set() nodes_to_visit = {source_node} while nodes_to_visit: diff --git a/pytools/mpi.py b/pytools/mpi.py index f74c130751c3dbb83720303249b8c97858154756..b446c25472b4f323d6cd0eb6c8bec0aa2991014c 100644 --- a/pytools/mpi.py +++ b/pytools/mpi.py @@ -32,8 +32,8 @@ MPI helper functionality .. autofunction:: pytest_raises_on_rank """ +from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager -from typing import Generator, Tuple, Type, Union def check_for_mpi_relaunch(argv): @@ -67,10 +67,10 @@ def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None): @contextmanager -def pytest_raises_on_rank(my_rank: int, fail_rank: int, - expected_exception: Union[Type[BaseException], - Tuple[Type[BaseException], ...]]) \ - -> Generator[AbstractContextManager, None, None]: +def pytest_raises_on_rank( + my_rank: int, fail_rank: int, + expected_exception: type[BaseException] | tuple[type[BaseException], ...], + ) -> Generator[AbstractContextManager, None, None]: """ Like :func:`pytest.raises`, but only expect an exception on rank *fail_rank*. """ diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 3f89887f046d9c7acab623c2fbf3b598d78b1810..62a7f5d16761b0482069c3429fb7f3e7fb302611 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -35,21 +35,10 @@ import os import pickle import sqlite3 import sys +from collections.abc import Callable, Iterator, Mapping from dataclasses import fields as dc_fields, is_dataclass from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - Callable, - FrozenSet, - Iterator, - Mapping, - Optional, - Protocol, - Tuple, - TypeVar, - cast, -) +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast from warnings import warn @@ -318,11 +307,11 @@ class KeyBuilder: def update_for_bytes(key_hash: Hash, key: bytes) -> None: key_hash.update(key) - def update_for_tuple(self, key_hash: Hash, key: Tuple[Any, ...]) -> None: + def update_for_tuple(self, key_hash: Hash, key: tuple[Any, ...]) -> None: for obj_i in key: self.rec(key_hash, obj_i) - def update_for_frozenset(self, key_hash: Hash, key: FrozenSet[Any]) -> None: + def update_for_frozenset(self, key_hash: Hash, key: frozenset[Any]) -> None: from pytools import unordered_hash unordered_hash( @@ -469,10 +458,10 @@ V = TypeVar("V") class _PersistentDictBase(Mapping[K, V]): def __init__(self, identifier: str, - key_builder: Optional[KeyBuilder] = None, - container_dir: Optional[str] = None, + key_builder: KeyBuilder | None = None, + container_dir: str | None = None, enable_wal: bool = False, - safe_sync: Optional[bool] = None) -> None: + safe_sync: bool | None = None) -> None: self.identifier = identifier self.conn = None @@ -580,7 +569,7 @@ class _PersistentDictBase(Mapping[K, V]): return cursor - def _exec_sql_fn(self, fn: Callable[[], T]) -> Optional[T]: + def _exec_sql_fn(self, fn: Callable[[], T]) -> T | None: n = 0 with self.mutex: @@ -643,7 +632,7 @@ class _PersistentDictBase(Mapping[K, V]): for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"): yield pickle.loads(row[0])[1] - def items(self) -> Iterator[Tuple[K, V]]: # type: ignore[override] + def items(self) -> Iterator[tuple[K, V]]: # type: ignore[override] """Return an iterator over the items in the dictionary.""" for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"): yield pickle.loads(row[0]) @@ -689,11 +678,11 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]): .. automethod:: fetch """ def __init__(self, identifier: str, - key_builder: Optional[KeyBuilder] = None, - container_dir: Optional[str] = None, + key_builder: KeyBuilder | None = None, + container_dir: str | None = None, *, enable_wal: bool = False, - safe_sync: Optional[bool] = None, + safe_sync: bool | None = None, in_mem_cache_size: int = 256) -> None: """ :arg identifier: a filename-compatible string identifying this @@ -747,10 +736,10 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]): raise ReadOnlyEntryError("WriteOncePersistentDict, " "tried overwriting key") from e - def _fetch_uncached(self, keyhash: str) -> Tuple[K, V]: + def _fetch_uncached(self, keyhash: str) -> tuple[K, V]: # This method is separate from fetch() to allow for LRU caching - def fetch_inner() -> Optional[Tuple[Any]]: + def fetch_inner() -> tuple[Any] | None: assert self.conn is not None # This is separate from fetch() so that the mutex covers the @@ -806,11 +795,11 @@ class PersistentDict(_PersistentDictBase[K, V]): """ def __init__(self, identifier: str, - key_builder: Optional[KeyBuilder] = None, - container_dir: Optional[str] = None, + key_builder: KeyBuilder | None = None, + container_dir: str | None = None, *, enable_wal: bool = False, - safe_sync: Optional[bool] = None) -> None: + safe_sync: bool | None = None) -> None: """ :arg identifier: a filename-compatible string identifying this dictionary @@ -840,7 +829,7 @@ class PersistentDict(_PersistentDictBase[K, V]): def fetch(self, key: K) -> V: keyhash = self.key_builder(key) - def fetch_inner() -> Optional[Tuple[Any]]: + def fetch_inner() -> tuple[Any] | None: assert self.conn is not None # This is separate from fetch() so that the mutex covers the diff --git a/pytools/spatial_btree.py b/pytools/spatial_btree.py index e1f2e91bc7a9ee4dc7b3be7080c6243201fd8a1b..83b466c070666416260e8367c69373ccfa9a79ed 100644 --- a/pytools/spatial_btree.py +++ b/pytools/spatial_btree.py @@ -166,7 +166,7 @@ class SpatialBinaryTreeBucket: (Path.CLOSEPOLY, (el[0], el[1])), ] - codes, verts = zip(*pathdata) + codes, verts = zip(*pathdata, strict=True) path = Path(verts, codes) patch = mpatches.PathPatch(path, **kwargs) pt.gca().add_patch(patch) diff --git a/pytools/stopwatch.py b/pytools/stopwatch.py index d0c9234d98310ec1b478f526d8b2c685f0a0a414..9887b7339f2306e9385667e86751e2938545c076 100644 --- a/pytools/stopwatch.py +++ b/pytools/stopwatch.py @@ -1,5 +1,4 @@ import time -from typing import List, Optional from pytools import DependentDictionary, Reference @@ -7,7 +6,7 @@ from pytools import DependentDictionary, Reference class StopWatch: def __init__(self) -> None: self.Elapsed = 0.0 - self.LastStart: Optional[float] = None + self.LastStart: float | None = None def start(self) -> "StopWatch": assert self.LastStart is None @@ -57,7 +56,7 @@ class EtaEstimator: self.total_steps = total_steps assert total_steps > 0 - def estimate(self, done: int) -> Optional[float]: + def estimate(self, done: int) -> float | None: fraction_done = done / self.total_steps time_spent = self.stopwatch.elapsed() @@ -72,7 +71,7 @@ def print_job_summary() -> None: print(key, " " * (50 - len(key)), value) -HIDDEN_JOBS: List[str] = [] -VISIBLE_JOBS: List[str] = [] +HIDDEN_JOBS: list[str] = [] +VISIBLE_JOBS: list[str] = [] JOB_TIMES = DependentDictionary(lambda x: 0) PRINT_JOBS = Reference(True) diff --git a/pytools/tag.py b/pytools/tag.py index 97073de5ad803025c783c80c431133c7df44e1f5..3ebf9cc295d4d182603e2044a4e95de5536cd7a3 100644 --- a/pytools/tag.py +++ b/pytools/tag.py @@ -25,18 +25,9 @@ Internal stuff that is only here because the documentation tool wants it from __future__ import annotations +from collections.abc import Iterable from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - FrozenSet, - Iterable, - Set, - Tuple, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, TypeVar from warnings import warn from typing_extensions import Self, dataclass_transform @@ -90,7 +81,7 @@ class DottedName: .. automethod:: from_class """ - def __init__(self, name_parts: Tuple[str, ...]) -> None: + def __init__(self, name_parts: tuple[str, ...]) -> None: if len(name_parts) == 0: raise ValueError("empty name parts") @@ -175,18 +166,18 @@ class UniqueTag(Tag): # }}} -ToTagSetConvertible = Union[Iterable[Tag], Tag, None] +ToTagSetConvertible = Iterable[Tag] | Tag | None TagT = TypeVar("TagT", bound="Tag") # {{{ UniqueTag rules checking @memoize -def _immediate_unique_tag_descendants(cls: type[Tag]) -> FrozenSet[type[Tag]]: +def _immediate_unique_tag_descendants(cls: type[Tag]) -> frozenset[type[Tag]]: if UniqueTag in cls.__bases__: return frozenset([cls]) else: - result: FrozenSet[type[Tag]] = frozenset() + result: frozenset[type[Tag]] = frozenset() for base in cls.__bases__: result = result | _immediate_unique_tag_descendants(base) return result @@ -201,14 +192,14 @@ class NonUniqueTagError(ValueError): pass -def check_tag_uniqueness(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: +def check_tag_uniqueness(tags: frozenset[Tag]) -> frozenset[Tag]: """Ensure that *tags* obeys the rules set forth in :class:`UniqueTag`. If not, raise :exc:`NonUniqueTagError`. If any *tags* are not subclasses of :class:`Tag`, a :exc:`TypeError` will be raised. :returns: *tags* """ - unique_tag_descendants: Set[type[Tag]] = set() + unique_tag_descendants: set[type[Tag]] = set() for tag in tags: if not isinstance(tag, Tag): raise TypeError(f"'{tag}' is not an instance of pytools.tag.Tag") @@ -227,7 +218,7 @@ def check_tag_uniqueness(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: # }}} -def normalize_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]: +def normalize_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: if isinstance(tags, Tag): tags = frozenset([tags]) elif tags is None: @@ -255,7 +246,7 @@ class Taggable: """ if not TYPE_CHECKING: - def __init__(self, tags: FrozenSet[Tag] = frozenset()): + def __init__(self, tags: frozenset[Tag] = frozenset()): warn("The Taggable constructor is deprecated. " "Subclasses must declare their own storage for .tags. " "The constructor will disappear in 2025.x.", @@ -269,10 +260,10 @@ class Taggable: # type-checking only so that self.tags = ... in subclasses still works if TYPE_CHECKING: @property - def tags(self) -> FrozenSet[Tag]: + def tags(self) -> frozenset[Tag]: ... - def _with_new_tags(self, tags: FrozenSet[Tag]) -> Self: + def _with_new_tags(self, tags: frozenset[Tag]) -> Self: """ Returns a copy of *self* with the specified tags. This method should be overridden by subclasses. @@ -313,7 +304,7 @@ class Taggable: return self._with_new_tags(tags=check_tag_uniqueness(new_tags)) @memoize_method - def tags_of_type(self, tag_t: Type[TagT]) -> FrozenSet[TagT]: + def tags_of_type(self, tag_t: type[TagT]) -> frozenset[TagT]: """ Returns *self*'s tags of type *tag_t*. """ @@ -322,7 +313,7 @@ class Taggable: if isinstance(tag, tag_t)}) @memoize_method - def tags_not_of_type(self, tag_t: Type[TagT]) -> FrozenSet[Tag]: + def tags_not_of_type(self, tag_t: type[TagT]) -> frozenset[Tag]: """ Returns *self*'s tags that are not of type *tag_t*. """ @@ -346,8 +337,8 @@ class Taggable: _depr_name_to_replacement_and_obj = { "TagsType": ( - "FrozenSet[Tag]", - FrozenSet[Tag], 2023), + "frozenset[Tag]", + frozenset[Tag], 2023), "TagOrIterableType": ( "ToTagSetConvertible", ToTagSetConvertible, 2023), diff --git a/pytools/test/test_data_table.py b/pytools/test/test_data_table.py index c259765715a054d644ea23c5cd4b40ee3d8a024e..7b40baa6032a683029007a0f02c1775ac6ab63ac 100644 --- a/pytools/test/test_data_table.py +++ b/pytools/test/test_data_table.py @@ -70,7 +70,7 @@ def test_aggregate(): def test_aggregate_2(): from pytools.datatable import DataTable - tbl = DataTable(["step", "value"], list(zip(list(range(20)), list(range(20))))) + tbl = DataTable(["step", "value"], list(zip(range(20), range(20), strict=True))) agg = tbl.aggregated(["step"], "value", max) assert agg.column_data("step") == list(range(20)) assert agg.column_data("value") == list(range(20)) diff --git a/pytools/test/test_graph_tools.py b/pytools/test/test_graph_tools.py index 57c462959964772338694ab2ff8f4ab6fa3e4f8b..8284766e5d354119c0611b7067d6619a9860fe8c 100644 --- a/pytools/test/test_graph_tools.py +++ b/pytools/test/test_graph_tools.py @@ -57,7 +57,7 @@ def test_compute_topological_order(): disconnected = {1: [], 2: [], 3: []} assert len(compute_topological_order(disconnected)) == 3 - line = list(zip(range(10), ([i] for i in range(1, 11)))) + line = list(zip(range(10), ([i] for i in range(1, 11)), strict=True)) import random random.seed(0) random.shuffle(line) diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index b04d43a09c41cfcb6189fd51aab93a9e8c55f5cf..727e589edf1bb54f189530daf10e7c3eb21dfeb6 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -3,7 +3,7 @@ import sys import tempfile from dataclasses import dataclass from enum import Enum, IntEnum -from typing import Any, Dict, Optional +from typing import Any, Optional import pytest @@ -29,7 +29,7 @@ class PDictTestingKeyOrValue: hash_key = val self.hash_key = hash_key - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: return {"val": self.val, "hash_key": self.hash_key} def __eq__(self, other: Any) -> bool: @@ -87,11 +87,11 @@ def test_persistent_dict_storage_and_lookup() -> None: for i in range(20)] values = [randrange(2000) for i in range(20)] - d = dict(zip(keys, values)) + d = dict(zip(keys, values, strict=True)) # {{{ check lookup - for k, v in zip(keys, values): + for k, v in zip(keys, values, strict=True): pdict[k] = v for k, v in d.items(): @@ -102,7 +102,7 @@ def test_persistent_dict_storage_and_lookup() -> None: # {{{ check updating - for k, v in zip(keys, values): + for k, v in zip(keys, values, strict=True): pdict[k] = v + 1 for k, v in d.items(): @@ -113,7 +113,7 @@ def test_persistent_dict_storage_and_lookup() -> None: # {{{ check store_if_not_present - for k, _ in zip(keys, values): + for k, _ in zip(keys, values, strict=True): pdict.store_if_not_present(k, d[k] + 2) for k, v in d.items(): @@ -809,10 +809,10 @@ def test_keys_values_items(): pdict[i] = i # This also tests deterministic iteration order - assert len(list(pdict.keys())) == 10000 == len(set(pdict.keys())) + assert len(pdict) == 10000 == len(set(pdict)) assert list(pdict.keys()) == list(range(10000)) assert list(pdict.values()) == list(range(10000)) - assert list(pdict.items()) == list(zip(list(pdict.keys()), range(10000))) + assert list(pdict.items()) == list(zip(pdict, range(10000), strict=True)) assert ([k for k in pdict.keys()] # noqa: C416 == list(pdict.keys()) diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index c262fc91264700a43e3406ddb886f9d1e047b6ec..4babe70c42a2184d62c3320ca112e1eed8c830c7 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -24,7 +24,6 @@ THE SOFTWARE. import logging import sys from dataclasses import dataclass -from typing import FrozenSet import pytest @@ -509,7 +508,7 @@ def test_tag() -> None: # Need a subclass that defines the copy function in order to test. @tag_dataclass class TaggableWithCopy(Taggable): - tags: FrozenSet[Tag] + tags: frozenset[Tag] def _with_new_tags(self, tags): return TaggableWithCopy(tags)