diff --git a/pytools/__init__.py b/pytools/__init__.py index cd0fb0d267d41fb8ca2d1a94910423b4ac5b39d6..233974ce946d8d46aaa15c312377ee53a1d77436 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -27,7 +27,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import re from functools import reduce, wraps import operator @@ -35,11 +34,16 @@ import sys import logging from typing import ( cast, Any, Callable, Dict, Hashable, Iterable, - List, Optional, Set, Tuple, TypeVar) + List, Optional, Set, Tuple, TypeVar, Union) import builtins +import math from sys import intern +try: + from typing import SupportsIndex +except ImportError: + from typing_extensions import SupportsIndex # These are deprecated and will go away in 2022. all = builtins.all @@ -198,7 +202,7 @@ F = TypeVar("F", bound=Callable[..., Any]) # {{{ code maintenance class MovedFunctionDeprecationWrapper: - def __init__(self, f, deadline=None): + def __init__(self, f: F, deadline: Optional[Union[int, str]] = None) -> None: if deadline is None: deadline = "the future" @@ -267,43 +271,85 @@ def delta(x, y): return 0 -def levi_civita(tup): - """Compute an entry of the Levi-Civita tensor for the indices *tuple*.""" +def levi_civita(tup: Tuple[int, ...]) -> int: + """Compute an entry of the Levi-Civita symbol for the indices *tuple*.""" if len(tup) == 2: i, j = tup - return j-i + return j - i if len(tup) == 3: i, j, k = tup - return (j-i)*(k-i)*(k-j)/2 + return (j-i) * (k-i) * (k-j) // 2 else: - raise NotImplementedError + raise NotImplementedError(f"Levi-Civita symbol in {len(tup)} dimensions") -def factorial(n): - from operator import mul - assert n == int(n) - return reduce(mul, (i for i in range(1, n+1)), 1) +factorial = MovedFunctionDeprecationWrapper(math.factorial, deadline=2023) +try: + # NOTE: only available in python >= 3.8 + perm = MovedFunctionDeprecationWrapper(math.perm, deadline=2023) +except AttributeError: + def _unchecked_perm(n, k): + result = 1 + while k: + result *= n + n -= 1 + k -= 1 -def perm(n, k): - """Return P(n, k), the number of permutations of length k drawn from n - choices. - """ - result = 1 - assert k > 0 - while k: - result *= n - n -= 1 - k -= 1 + return result - return result + def perm(n: SupportsIndex, # type: ignore[misc] + k: Optional[SupportsIndex] = None) -> int: + """ + :returns: :math:`P(n, k)`, the number of permutations of length :math:`k` + drawn from :math:`n` choices. + """ + from warnings import warn + warn("This function is deprecated and will go away in 2023. " + "Use `math.perm` instead, which is available from Python 3.8.", + DeprecationWarning, stacklevel=2) + if k is None: + return math.factorial(n) -def comb(n, k): - """Return C(n, k), the number of combinations (subsets) - of length k drawn from n choices. - """ - return perm(n, k)//factorial(k) + import operator + n, k = operator.index(n), operator.index(k) + if k > n: + return 0 + + if k < 0: + raise ValueError("k must be a non-negative integer") + + if n < 0: + raise ValueError("n must be a non-negative integer") + + from numbers import Integral + if not isinstance(k, Integral): + raise TypeError(f"'{type(k).__name__}' object cannot be interpreted " + "as an integer") + + if not isinstance(n, Integral): + raise TypeError(f"'{type(n).__name__}' object cannot be interpreted " + "as an integer") + + return _unchecked_perm(n, k) + +try: + # NOTE: only available in python >= 3.8 + comb = MovedFunctionDeprecationWrapper(math.comb, deadline=2023) +except AttributeError: + def comb(n: SupportsIndex, # type: ignore[misc] + k: SupportsIndex) -> int: + """ + :returns: :math:`C(n, k)`, the number of combinations (subsets) + of length :math:`k` drawn from :math:`n` choices. + """ + from warnings import warn + warn("This function is deprecated and will go away in 2023. " + "Use `math.comb` instead, which is available from Python 3.8.", + DeprecationWarning, stacklevel=2) + + return _unchecked_perm(n, k) // math.factorial(k) def norm_1(iterable):