#! /bin/sh
rsync --verbose --archive --delete _build/html/ doc-upload:doc/pytools
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "pytools"
version = "2024.1.21"
description = "A collection of tools for Python"
readme = "README.rst"
license = { text = "MIT" }
authors = [
{ name = "Andreas Kloeckner", email = "inform@tiker.net" },
]
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Other Audience",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Scientific/Engineering :: Visualization",
"Topic :: Software Development :: Libraries",
"Topic :: Utilities",
]
dependencies = [
"platformdirs>=2.2",
# for dataclass_transform with frozen_default
"typing-extensions>=4.5",
]
[project.optional-dependencies]
numpy = [
"numpy>=1.6",
]
test = [
"mypy",
"pytest",
"ruff",
]
siphash = [
"siphash24>=1.6",
]
[project.urls]
Documentation = "https://documen.tician.de/pytools/"
Homepage = "https://github.com/inducer/pytools/"
[tool.hatch.build.targets.sdist]
exclude = [
"/.git*",
"/doc/_build",
"/.editorconfig",
"/run-*.sh",
]
[tool.ruff]
preview = true
[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear
"C", # flake8-comprehensions
"E", # pycodestyle
"F", # pyflakes
"G", # flake8-logging-format
"I", # flake8-isort
"N", # pep8-naming
"NPY", # numpy
"Q", # flake8-quotes
"UP", # pyupgrade
"RUF", # ruff
"W", # pycodestyle
"TC",
]
extend-ignore = [
"C90", # McCabe complexity
"E221", # multiple spaces before operator
"E226", # missing whitespace around arithmetic operator
"E402", # module-level import not at top of file
"UP031", # use f-strings instead of %
"UP032", # use f-strings instead of .format
]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
inline-quotes = "double"
multiline-quotes = "double"
[tool.ruff.lint.isort]
combine-as-imports = true
known-local-folder = [
"pytools",
]
lines-after-imports = 2
required-imports = ["from __future__ import annotations"]
[tool.ruff.lint.pep8-naming]
extend-ignore-names = ["update_for_*"]
[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true
warn_unused_ignores = true
# TODO: enable this at some point
# check_untyped_defs = true
[tool.typos.default]
extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$"
]
# pylint: disable=too-many-lines
# (Yes, it has a point!)
from __future__ import annotations
__copyright__ = """
Copyright (C) 2009-2013 Andreas Kloeckner
Copyright (C) 2013- University of Illinois Board of Trustees
Copyright (C) 2020 Matt Wala
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
......@@ -25,18 +27,30 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import builtins
import logging
import operator
import re
import sys
from collections.abc import Callable, Collection, Hashable, Iterable, Mapping, Sequence
from functools import reduce, wraps
from sys import intern
from typing import (
Any,
ClassVar,
Concatenate,
Generic,
ParamSpec,
TypeVar,
)
from typing_extensions import dataclass_transform
# These are deprecated and will go away in 2022.
all = builtins.all
any = builtins.any
__doc__ = """
A Collection of Utilities
......@@ -46,8 +60,6 @@ Math
----
.. autofunction:: levi_civita
Assertive accessors
-------------------
......@@ -63,8 +75,10 @@ Memoization
.. autofunction:: memoize
.. autofunction:: memoize_on_first_arg
.. autofunction:: memoize_method
.. autofunction:: memoize_method_with_uncached
.. autofunction:: memoize_in
.. autofunction:: keyed_memoize_on_first_arg
.. autofunction:: keyed_memoize_method
.. autofunction:: keyed_memoize_in
Argmin/max
----------
......@@ -83,24 +97,17 @@ Permutations, Tuples, Integer sequences
---------------------------------------
.. autofunction:: wandering_element
.. autofunction:: indices_in_shape
.. autofunction:: generate_nonnegative_integer_tuples_below
.. autofunction:: generate_nonnegative_integer_tuples_summing_to_at_most
.. autofunction:: generate_all_nonnegative_integer_tuples
.. autofunction:: generate_all_integer_tuples_below
.. autofunction:: generate_all_integer_tuples
.. autofunction:: generate_permutations
.. autofunction:: generate_unique_permutations
Graph Algorithms
----------------
.. autofunction:: a_star
Formatting
----------
.. autoclass:: Table
.. autofunction:: merge_tables
.. autofunction:: string_histogram
.. autofunction:: word_wrap
......@@ -122,6 +129,12 @@ Name generation
.. autofunction:: generate_numbered_unique_names
.. autoclass:: UniqueNameGenerator
Deprecation Warnings
--------------------
.. autofunction:: deprecate_keyword
.. autofunction:: module_getattr_for_deprecations
Functions for dealing with (large) auxiliary files
--------------------------------------------------
......@@ -144,55 +157,190 @@ Log utilities
.. autoclass:: ProcessLogger
.. autoclass:: DebugProcessLogger
.. autoclass:: log_process
Sorting in natural order
------------------------
.. autofunction:: natorder
.. autofunction:: natsorted
Backports of newer Python functionality
---------------------------------------
.. autofunction:: resolve_name
Hashing
-------
.. autofunction:: unordered_hash
Sampling
--------
.. autofunction:: sphere_sample_equidistant
.. autofunction:: sphere_sample_fibonacci
String utilities
----------------
.. autofunction:: strtobool
.. autofunction:: to_identifier
Set-like functions for iterables
--------------------------------
These functions provide set-like operations on iterables. In contrast to
Python's built-in set type, they maintain the internal order of elements.
.. autofunction:: unique
.. autofunction:: unique_difference
.. autofunction:: unique_intersection
.. autofunction:: unique_union
Functionality for dataclasses
-----------------------------
.. autofunction:: opt_frozen_dataclass
Type Variables Used
-------------------
.. class:: T
.. class:: R
Generic unbound invariant :class:`typing.TypeVar`.
.. class:: F
Generic invariant :class:`typing.TypeVar` bound to a :class:`typing.Callable`.
.. class:: P
Generic unbound invariant :class:`typing.ParamSpec`.
"""
# {{{ type variables
T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")
# }}}
# {{{ code maintenance
# Undocumented on purpose for now, unclear that this is a great idea, given
# that typing.deprecated exists.
class MovedFunctionDeprecationWrapper:
def __init__(self, f: F, deadline: int | str | None = None) -> None:
if deadline is None:
deadline = "the future"
self.f = f
self.deadline = deadline
def __call__(self, *args, **kwargs):
from warnings import warn
warn(f"This function is deprecated and will go away in {self.deadline}. "
f"Use {self.f.__module__}.{self.f.__name__} instead.",
DeprecationWarning, stacklevel=2)
return self.f(*args, **kwargs)
def deprecate_keyword(oldkey: str,
newkey: str | None = None, *,
deadline: str | None = None):
"""Decorator used to deprecate function keyword arguments.
:arg oldkey: deprecated argument name.
:arg newkey: new argument name that serves the same purpose, if any.
:arg deadline: expected time frame for the removal of the deprecated argument.
"""
from warnings import warn
if deadline is None:
deadline = "the future"
def wrapper(func):
@wraps(func)
def inner_wrapper(*args, **kwargs):
if oldkey in kwargs:
if newkey is None:
warn(f"The '{oldkey}' keyword is deprecated and will "
f"go away in {deadline}.",
DeprecationWarning, stacklevel=2)
else:
warn(f"The '{oldkey}' keyword is deprecated and will "
f"go away in {deadline}. "
f"Use '{newkey}' instead.",
DeprecationWarning, stacklevel=2)
if newkey in kwargs:
raise ValueError(f"Cannot use '{oldkey}' "
f"and '{newkey}' in the same call.")
kwargs[newkey] = kwargs[oldkey]
del kwargs[oldkey]
return func(*args, **kwargs)
return inner_wrapper
return wrapper
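# Usage sketch (not part of the original source; the function and keyword names
# below are made up): renaming a keyword argument while keeping the old spelling
# working with a DeprecationWarning.
#
#     @deprecate_keyword("ary", "values", deadline="2026")
#     def total(values):
#         return sum(values)
#
#     total(ary=[1, 2, 3])  # warns about 'ary', forwards it to 'values', returns 6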
def module_getattr_for_deprecations(
module_name: str,
depr_name_to_replacement_and_obj: Mapping[
str, tuple[str, object, str | int]
],
name: str
) -> object:
"""A helper to construct module-level :meth:`object.__getattr__` functions
so that deprecated names can still be found but raise a warning.
The typical usage pattern is as follows::
__getattr__ = partial(module_getattr_for_deprecations, __name__, {
"OldName": ("NewName", NewName, 2026),
})
"""
replacement_and_obj = depr_name_to_replacement_and_obj.get(name, None)
if replacement_and_obj is not None:
replacement, obj, deadline = replacement_and_obj
from warnings import warn
warn(f"'{module_name}.{name}' is deprecated. "
f"Use '{replacement}' instead. "
f"'{module_name}.{name}' will continue to work until {deadline}.",
DeprecationWarning, stacklevel=2)
return obj
raise AttributeError(name)
# }}}
# {{{ math
def delta(x, y):
if x == y:
return 1
return 0
def levi_civita(tup: tuple[int, ...]) -> int:
"""Compute an entry of the Levi-Civita symbol for the indices *tuple*."""
if len(tup) == 2:
i, j = tup
return j - i
if len(tup) == 3:
i, j, k = tup
return (j-i) * (k-i) * (k-j) // 2
raise NotImplementedError(f"Levi-Civita symbol in {len(tup)} dimensions")
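# Quick sanity check (illustrative, not part of the original source): the symbol
# is +1 for even permutations, -1 for odd permutations, and 0 on repeated indices.
#
#     levi_civita((0, 1, 2))   # -> 1
#     levi_civita((0, 2, 1))   # -> -1
#     levi_civita((0, 0, 1))   # -> 0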
def norm_1(iterable):
......@@ -211,7 +359,7 @@ def norm_p(iterable, p):
return sum(i**p for i in iterable)**(1/p)
class Norm(object):
class Norm:
def __init__(self, p):
self.p = p
......@@ -225,14 +373,18 @@ class Norm(object):
# {{{ record
class RecordWithoutPickling(object):
class RecordWithoutPickling:
"""An aggregate of named sub-variables. Assumes that each record sub-type
will be individually derived from this class.
"""
__slots__ = []
__slots__: ClassVar[list[str]] = []
fields: ClassVar[set[str]]
def __init__(self, valuedict=None, exclude=None, **kwargs):
def __init__(self,
valuedict: Mapping[str, Any] | None = None,
exclude: Sequence[str] | None = None,
**kwargs: Any) -> None:
assert self.__class__ is not Record
if exclude is None:
......@@ -246,7 +398,7 @@ class RecordWithoutPickling(object):
if valuedict is not None:
kwargs.update(valuedict)
for key, value in six.iteritems(kwargs):
for key, value in kwargs.items():
if key not in exclude:
fields.add(key)
setattr(self, key, value)
......@@ -264,10 +416,10 @@ class RecordWithoutPickling(object):
return self.__class__(**self.get_copy_kwargs(**kwargs))
def __repr__(self):
return "%s(%s)" % (
return "{}({})".format(
self.__class__.__name__,
", ".join("%s=%r" % (fld, getattr(self, fld))
for fld in self.__class__.fields
", ".join(f"{fld}={getattr(self, fld)!r}"
for fld in sorted(self.__class__.fields)
if hasattr(self, fld)))
def register_fields(self, new_fields):
......@@ -282,18 +434,18 @@ class RecordWithoutPickling(object):
# This method is implemented to avoid pylint 'no-member' errors for
# attribute access.
raise AttributeError(
"'%s' object has no attribute '%s'" % (
"'{}' object has no attribute '{}'".format(
self.__class__.__name__, name))
class Record(RecordWithoutPickling):
__slots__ = []
__slots__: ClassVar[list[str]] = []
def __getstate__(self):
return dict(
(key, getattr(self, key))
return {
key: getattr(self, key)
for key in self.__class__.fields
if hasattr(self, key))
if hasattr(self, key)}
def __setstate__(self, valuedict):
try:
......@@ -301,11 +453,13 @@ class Record(RecordWithoutPickling):
except AttributeError:
self.__class__.fields = fields = set()
for key, value in six.iteritems(valuedict):
for key, value in valuedict.items():
fields.add(key)
setattr(self, key, value)
def __eq__(self, other):
if self is other:
return True
return (self.__class__ == other.__class__
and self.__getstate__() == other.__getstate__())
......@@ -314,12 +468,20 @@ class Record(RecordWithoutPickling):
class ImmutableRecordWithoutPickling(RecordWithoutPickling):
"Hashable record. Does not explicitly enforce immutability."
"""Hashable record. Does not explicitly enforce immutability."""
def __init__(self, *args, **kwargs):
RecordWithoutPickling.__init__(self, *args, **kwargs)
self._cached_hash = None
def __hash__(self):
return hash(
(type(self),) + tuple(getattr(self, field)
for field in self.__class__.fields))
# This attribute may vanish during pickling.
if getattr(self, "_cached_hash", None) is None:
self._cached_hash = hash((
type(self),
*(getattr(self, field) for field in self.__class__.fields)
))
return self._cached_hash
class ImmutableRecord(ImmutableRecordWithoutPickling, Record):
......@@ -328,20 +490,21 @@ class ImmutableRecord(ImmutableRecordWithoutPickling, Record):
# }}}
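# Usage sketch (not part of the original source; the subclass and field names
# are hypothetical): Record subclasses store keyword arguments as attributes,
# and copy() produces a new instance with selected fields replaced.
#
#     class BoundingBox(Record):
#         pass
#
#     box = BoundingBox(lower=(0, 0), upper=(1, 2))
#     box.copy(upper=(3, 3))   # new BoundingBox with 'upper' replaced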
class Reference:
def __init__(self, value):
self.value = value
def get(self):
from warnings import warn
warn("Reference.get() is deprecated -- use ref.value instead")
warn("Reference.get() is deprecated -- use ref.value instead. "
"This will stop working in 2025.", stacklevel=2)
return self.value
def set(self, value):
self.value = value
class FakeList(object):
class FakeList:
def __init__(self, f, length):
self._Length = length
self._Function = f
......@@ -357,9 +520,9 @@ class FakeList(object):
return self._Function(index)
# {{{ dependent dictionary ----------------------------------------------------
# {{{ dependent dictionary
class DependentDictionary(object):
class DependentDictionary:
def __init__(self, f, start=None):
if start is None:
start = {}
......@@ -372,7 +535,7 @@ class DependentDictionary(object):
def __contains__(self, key):
try:
self[key] # pylint: disable=pointless-statement
self[key]
return True
except KeyError:
return False
......@@ -386,17 +549,17 @@ class DependentDictionary(object):
def __setitem__(self, key, value):
self._Dictionary[key] = value
def genuineKeys(self): # noqa
def genuineKeys(self): # noqa: N802
return list(self._Dictionary.keys())
def iteritems(self):
return six.iteritems(self._Dictionary)
return self._Dictionary.items()
def iterkeys(self):
return six.iterkeys(self._Dictionary)
return self._Dictionary.keys()
def itervalues(self):
return six.itervalues(self._Dictionary)
return self._Dictionary.values()
# }}}
......@@ -405,7 +568,7 @@ class DependentDictionary(object):
# {{{ assertive accessors
def one(iterable):
def one(iterable: Iterable[T]) -> T:
"""Return the first entry of *iterable*. Assert that *iterable* has only
that one entry.
"""
......@@ -413,7 +576,7 @@ def one(iterable):
try:
v = next(it)
except StopIteration:
raise ValueError("empty iterable passed to 'one()'")
raise ValueError("empty iterable passed to 'one()'") from None
def no_more():
try:
......@@ -427,12 +590,15 @@ def one(iterable):
return v
def is_single_valued(iterable, equality_pred=operator.eq):
def is_single_valued(
iterable: Iterable[T],
equality_pred: Callable[[T, T], bool] = operator.eq
) -> bool:
it = iter(iterable)
try:
first_item = next(it)
except StopIteration:
raise ValueError("empty iterable passed to 'single_valued()'")
raise ValueError("empty iterable passed to 'single_valued()'") from None
for other_item in it:
if not equality_pred(other_item, first_item):
......@@ -448,7 +614,10 @@ def all_roughly_equal(iterable, threshold):
equality_pred=lambda a, b: abs(a-b) < threshold)
def single_valued(iterable, equality_pred=operator.eq):
def single_valued(
iterable: Iterable[T],
equality_pred: Callable[[T, T], bool] = operator.eq
) -> T:
"""Return the first entry of *iterable*; Assert that other entries
are the same with the first entry of *iterable*.
"""
......@@ -456,7 +625,7 @@ def single_valued(iterable, equality_pred=operator.eq):
try:
first_item = next(it)
except StopIteration:
raise ValueError("empty iterable passed to 'single_valued()'")
raise ValueError("empty iterable passed to 'single_valued()'") from None
def others_same():
for other_item in it:
......@@ -472,7 +641,7 @@ def single_valued(iterable, equality_pred=operator.eq):
# {{{ memoization / attribute storage
def memoize(*args: F, **kwargs: Any) -> F:
"""Stores previously computed function values in a cache.
Two keyword-only arguments are supported:
......@@ -484,11 +653,13 @@ def memoize(*args, **kwargs):
which computes and returns the cache key.
"""
use_kw = bool(kwargs.pop("use_kwargs", False))
default_key_func: Callable[..., Any] | None
if use_kw:
def default_key_func(*inner_args, **inner_kwargs):
return inner_args, frozenset(inner_kwargs.items())
else:
default_key_func = None
......@@ -496,241 +667,291 @@ def memoize(*args, **kwargs):
if kwargs:
raise TypeError(
"memoize received unexpected keyword arguments: {}".format(
", ".join(kwargs.keys())))
if key_func is not None:
def _decorator(func):
def wrapper(*args, **kwargs):
key = key_func(*args, **kwargs)
try:
return func._memoize_dic[key]
except AttributeError:
# _memoize_dic doesn't exist yet.
result = func(*args, **kwargs)
func._memoize_dic = {key: result}
return result
except KeyError:
result = func(*args, **kwargs)
func._memoize_dic[key] = result
return result
from functools import update_wrapper
update_wrapper(wrapper, func)
return wrapper
else:
def _decorator(func):
def wrapper(*args):
try:
return func._memoize_dic[args]
except AttributeError:
# _memoize_dic doesn't exist yet.
result = func(*args)
func._memoize_dic = {args: result}
return result
except KeyError:
result = func(*args)
func._memoize_dic[args] = result
return result
from functools import update_wrapper
update_wrapper(wrapper, func)
return wrapper
if not args:
return _decorator  # type: ignore
if callable(args[0]) and len(args) == 1:
return _decorator(args[0])
raise TypeError(
f"memoize received unexpected position arguments: {args}")
FunctionValueCache = memoize
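# Usage sketch (not part of the original source): @memoize stores results keyed
# by the positional arguments in a dict attached to the function object.
#
#     @memoize
#     def fib(n):
#         return n if n < 2 else fib(n - 1) + fib(n - 2)
#
#     fib(30)   # each subproblem is computed once and then served from the cache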
class _HasKwargs:
pass
def memoize_on_first_arg(
function: Callable[Concatenate[T, P], R], *,
cache_dict_name: str | None = None) -> Callable[Concatenate[T, P], R]:
"""Like :func:`memoize_method`, but for functions that take the object
in which memoization information is stored as first argument.
Supports cache deletion via ``function_name.clear_cache(self)``.
"""
if cache_dict_name is None:
cache_dict_name = intern(
f"_memoize_dic_{function.__module__}{function.__name__}"
)
def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R:
if kwargs:
key = (_HasKwargs, frozenset(kwargs.items()), *args)
else:
key = args
assert cache_dict_name is not None
try:
return getattr(obj, cache_dict_name)[key]
except AttributeError:
attribute_error = True
except KeyError:
attribute_error = False
result = function(obj, *args, **kwargs)
if attribute_error:
object.__setattr__(obj, cache_dict_name, {key: result})
return result
getattr(obj, cache_dict_name)[key] = result
return result
def clear_cache(obj):
object.__delattr__(obj, cache_dict_name)
from functools import update_wrapper
new_wrapper = update_wrapper(wrapper, function)
# type-ignore because mypy has a point here, stuffing random attributes
# into the function's dict is moderately sketchy.
new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined]
return new_wrapper
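# Usage sketch (not part of the original source; 'element_volumes' and 'mesh'
# are hypothetical): the cache dict lives on the first argument, so the memoized
# values follow that object around.
#
#     @memoize_on_first_arg
#     def element_volumes(mesh, order=1):
#         ...
#
#     element_volumes(mesh)                  # cached on 'mesh'
#     element_volumes.clear_cache(mesh)      # drop the per-object cache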
def memoize_method(
method: Callable[Concatenate[T, P], R]
) -> Callable[Concatenate[T, P], R]:
"""Supports cache deletion via ``method_name.clear_cache(self)``.
.. versionchanged:: 2021.2
Can memoize methods on classes that do not allow setting attributes
(e.g. by overwriting ``__setattr__``), e.g. frozen :mod:`dataclasses`.
"""
return memoize_on_first_arg(method,
cache_dict_name=intern(f"_memoize_dic_{method.__name__}"))
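# Usage sketch (not part of the original source; the class and helper below are
# hypothetical): per-instance caching of an expensive method, which also works
# on frozen dataclasses since the cache is attached via object.__setattr__.
#
#     class Grid:
#         @memoize_method
#         def node_count(self):
#             return expensive_count(self)
#
#     g = Grid()
#     g.node_count()                    # computed once per instance
#     g.node_count.clear_cache(g)       # invalidate the cache for this instance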
def memoize_method_with_uncached(uncached_args=None, uncached_kwargs=None):
"""Supports cache deletion via ``method_name.clear_cache(self)``.
:arg uncached_args: a list of argument numbers
(0-based, not counting 'self' argument)
"""
class keyed_memoize_on_first_arg(Generic[T, P, R]): # noqa: N801
"""Like :func:`memoize_method`, but for functions that take the object
in which memoization information is stored as first argument.
if uncached_args is None:
uncached_args = []
if uncached_kwargs is None:
uncached_kwargs = set()
Supports cache deletion via ``function_name.clear_cache(self)``.
# delete starting from the end
uncached_args = sorted(uncached_args, reverse=True)
uncached_kwargs = list(uncached_kwargs)
:arg key: A function receiving the same arguments as the decorated function
which computes and returns the cache key.
:arg cache_dict_name: The name of the `dict` attribute in the instance
used to hold the cache.
def parametrized_decorator(method):
cache_dict_name = intern("_memoize_dic_"+method.__name__)
.. versionadded :: 2020.3
"""
def wrapper(self, *args, **kwargs):
cache_args = list(args)
cache_kwargs = kwargs.copy()
def __init__(self,
key: Callable[P, Hashable], *,
cache_dict_name: str | None = None) -> None:
self.key = key
self.cache_dict_name = cache_dict_name
for i in uncached_args:
if i < len(cache_args):
cache_args.pop(i)
def _default_cache_dict_name(self,
function: Callable[Concatenate[T, P], R]) -> str:
return intern(f"_memoize_dic_{function.__module__}{function.__name__}")
cache_args = tuple(cache_args)
def __call__(
self, function: Callable[Concatenate[T, P], R]
) -> Callable[Concatenate[T, P], R]:
cache_dict_name = self.cache_dict_name
key = self.key
if kwargs:
for name in uncached_kwargs:
cache_kwargs.pop(name, None)
if cache_dict_name is None:
cache_dict_name = self._default_cache_dict_name(function)
key = (
(_HasKwargs, frozenset(six.iteritems(cache_kwargs)))
+ cache_args)
else:
key = cache_args
def wrapper(obj: T, *args: P.args, **kwargs: P.kwargs) -> R:
cache_key = key(*args, **kwargs)
assert cache_dict_name is not None
try:
return getattr(self, cache_dict_name)[key]
return getattr(obj, cache_dict_name)[cache_key]
except AttributeError:
result = method(self, *args, **kwargs)
setattr(self, cache_dict_name, {key: result})
result = function(obj, *args, **kwargs)
object.__setattr__(obj, cache_dict_name, {cache_key: result})
return result
except KeyError:
result = method(self, *args, **kwargs)
getattr(self, cache_dict_name)[key] = result
result = function(obj, *args, **kwargs)
getattr(obj, cache_dict_name)[cache_key] = result
return result
def clear_cache(self):
delattr(self, cache_dict_name)
def clear_cache(obj):
object.__delattr__(obj, cache_dict_name)
if sys.version_info >= (2, 5):
from functools import update_wrapper
new_wrapper = update_wrapper(wrapper, method)
new_wrapper.clear_cache = clear_cache
from functools import update_wrapper
new_wrapper = update_wrapper(wrapper, function)
new_wrapper.clear_cache = clear_cache # type: ignore[attr-defined]
return new_wrapper
return parametrized_decorator
class keyed_memoize_method(keyed_memoize_on_first_arg): # noqa: N801
"""Like :class:`memoize_method`, but additionally uses a function *key* to
compute the key under which the function result is stored.
def memoize_method_nested(inner):
"""Adds a cache to a function nested inside a method. The cache is attached
to *memoize_cache_context* (if it exists) or *self* in the outer (method)
namespace.
Supports cache deletion via ``method_name.clear_cache(self)``.
Requires Python 2.5 or newer.
"""
:arg key: A function receiving the same arguments as the decorated function
which computes and returns the cache key.
from warnings import warn
warn("memoize_method_nested is deprecated. Use @memoize_in(self, 'identifier') "
"instead", DeprecationWarning, stacklevel=2)
.. versionadded :: 2020.3
from functools import wraps
cache_dict_name = intern("_memoize_inner_dic_%s_%s_%d"
% (inner.__name__, inner.__code__.co_filename,
inner.__code__.co_firstlineno))
.. versionchanged:: 2021.2
from inspect import currentframe
outer_frame = currentframe().f_back
cache_context = outer_frame.f_locals.get("memoize_cache_context")
if cache_context is None:
cache_context = outer_frame.f_locals.get("self")
Can memoize methods on classes that do not allow setting attributes
(e.g. by overwriting ``__setattr__``), e.g. frozen :mod:`dataclasses`.
"""
def _default_cache_dict_name(self, function):
return intern(f"_memoize_dic_{function.__name__}")
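# Usage sketch (not part of the original source; names are hypothetical): the
# *key* callable sees the same arguments as the method (minus 'self') and
# returns the hashable cache key, here ignoring a flag that does not affect
# the result.
#
#     class Library:
#         @keyed_memoize_method(key=lambda name, verbose=False: name)
#         def lookup(self, name, verbose=False):
#             ...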
try:
cache_dict = getattr(cache_context, cache_dict_name)
except AttributeError:
cache_dict = {}
setattr(cache_context, cache_dict_name, cache_dict)
@wraps(inner)
def new_inner(*args):
try:
return cache_dict[args]
except KeyError:
result = inner(*args)
cache_dict[args] = result
return result
class memoize_in: # noqa: N801
"""Adds a cache to the function it decorates. The cache is attached
to *container* and must be uniquely specified by *identifier* (i.e.
all functions using the same *container* and *identifier* will be using
the same cache). The decorated function may only receive positional
arguments.
return new_inner
.. note::
This function works well on nested functions, which
do not have stable global identifiers.
class memoize_in(object): # noqa
"""Adds a cache to a function nested inside a method. The cache is attached
to *object*.
.. versionchanged :: 2020.3
Requires Python 2.5 or newer.
*identifier* no longer needs to be a :class:`str`,
but it needs to be hashable.
.. versionchanged:: 2021.2.1
Can now use instances of classes as *container* that do not allow
setting attributes (e.g. by overwriting ``__setattr__``),
e.g. frozen :mod:`dataclasses`.
"""
def __init__(self, container, identifier):
key = "_pytools_memoize_in_dict_for_"+identifier
def __init__(self, container: Any, identifier: Hashable) -> None:
try:
self.cache_dict = getattr(container, key)
memoize_in_dict = container._pytools_memoize_in_dict
except AttributeError:
self.cache_dict = {}
setattr(container, key, self.cache_dict)
memoize_in_dict = {}
object.__setattr__(container, "_pytools_memoize_in_dict",
memoize_in_dict)
def __call__(self, inner):
from functools import wraps
self.cache_dict = memoize_in_dict.setdefault(identifier, {})
def __call__(self, inner: Callable[P, R]) -> Callable[P, R]:
@wraps(inner)
def new_inner(*args):
def new_inner(*args: P.args, **kwargs: P.kwargs) -> R:
assert not kwargs
try:
return self.cache_dict[args]
except KeyError:
result = inner(*args)
result = inner(*args, **kwargs)
self.cache_dict[args] = result
return result
return new_inner
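# Usage sketch (not part of the original source; 'process' and 'expensive' are
# made up): memoize_in attaches the cache for a nested helper to a surrounding
# container (often 'self'), so the helper stays cheap across calls.
#
#     def process(self, values):
#         @memoize_in(self, (process, "transform"))
#         def transform(v):
#             return expensive(v)
#         return [transform(v) for v in values]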
class keyed_memoize_in(Generic[P]): # noqa: N801
"""Like :class:`memoize_in`, but additionally uses a function *key* to
compute the key under which the function result is memoized.
:arg key: A function receiving the same arguments as the decorated function
which computes and returns the cache key.
.. versionadded :: 2021.2.1
"""
def __init__(self,
container: Any, identifier: Hashable,
key: Callable[P, Hashable]) -> None:
try:
memoize_in_dict = container._pytools_keyed_memoize_in_dict
except AttributeError:
memoize_in_dict = {}
object.__setattr__(container, "_pytools_keyed_memoize_in_dict",
memoize_in_dict)
self.cache_dict = memoize_in_dict.setdefault(identifier, {})
self.key = key
def __call__(self, inner: Callable[P, R]) -> Callable[P, R]:
@wraps(inner)
def new_inner(*args: P.args, **kwargs: P.kwargs) -> R:
assert not kwargs
key = self.key(*args, **kwargs)
try:
return self.cache_dict[key]
except KeyError:
result = inner(*args, **kwargs)
self.cache_dict[key] = result
return result
return new_inner
# }}}
......@@ -768,7 +989,7 @@ def monkeypatch_class(_name, bases, namespace):
assert len(bases) == 1, "Exactly one base class required"
base = bases[0]
for name, value in six.iteritems(namespace):
for name, value in namespace.items():
if name != "__metaclass__":
setattr(base, name, value)
return base
......@@ -779,22 +1000,24 @@ def monkeypatch_class(_name, bases, namespace):
# {{{ generic utilities
def add_tuples(t1, t2):
return tuple([t1v + t2v for t1v, t2v in zip(t1, t2)])
return tuple(t1v + t2v for t1v, t2v in zip(t1, t2, strict=True))
def negate_tuple(t1):
return tuple([-t1v for t1v in t1])
return tuple(-t1v for t1v in t1)
def shift(vec, dist):
"""Return a copy of C{vec} shifted by C{dist}.
"""Return a copy of *vec* shifted by *dist* such that
.. code:: python
@postcondition: C{shift(a, i)[j] == a[(i+j) % len(a)]}
shift(a, i)[j] == a[(i+j) % len(a)]
"""
result = vec[:]
N = len(vec) # noqa
N = len(vec) # noqa: N806
dist = dist % N
# modulo only returns positive distances!
......@@ -816,8 +1039,7 @@ def flatten(iterable):
Example: Turn [[a,b,c],[d,e,f]] into [a,b,c,d,e,f].
"""
for sublist in iterable:
for j in sublist:
yield j
yield from sublist
def general_sum(sequence):
......@@ -826,7 +1048,7 @@ def general_sum(sequence):
def linear_combination(coefficients, vectors):
result = coefficients[0] * vectors[0]
for c, v in zip(coefficients[1:], vectors[1:]):
for c, v in zip(coefficients[1:], vectors[1:], strict=True):
result += c*v
return result
......@@ -875,21 +1097,17 @@ def partition2(iterable):
return part_true, part_false
def product(iterable):
def product(iterable: Iterable[Any]) -> Any:
from operator import mul
return reduce(mul, iterable, 1)
def reverse_dictionary(the_dict):
result = {}
for key, value in six.iteritems(the_dict):
for key, value in the_dict.items():
if value in result:
raise RuntimeError(
"non-reversible mapping, duplicate key '%s'" % value)
f"non-reversible mapping, duplicate key '{value}'")
result[value] = key
return result
......@@ -944,7 +1162,7 @@ def find_max_where(predicate, prec=1e-5, initial_guess=1, fail_bound=1e38):
if mag > fail_bound:
raise RuntimeError("predicate appears to be true "
"everywhere, up to %g" % fail_bound)
f"everywhere, up to {fail_bound:g}")
lower_true = mag/2
upper_false = mag
......@@ -991,7 +1209,7 @@ def argmin2(iterable, return_value=False):
try:
current_argmin, current_min = next(it)
except StopIteration:
raise ValueError("argmin of empty iterable")
raise ValueError("argmin of empty iterable") from None
for arg, item in it:
if item < current_min:
......@@ -1000,8 +1218,7 @@ def argmin2(iterable, return_value=False):
if return_value:
return current_argmin, current_min
else:
return current_argmin
return current_argmin
def argmax2(iterable, return_value=False):
......@@ -1009,7 +1226,7 @@ def argmax2(iterable, return_value=False):
try:
current_argmax, current_max = next(it)
except StopIteration:
raise ValueError("argmax of empty iterable")
raise ValueError("argmax of empty iterable") from None
for arg, item in it:
if item > current_max:
......@@ -1018,8 +1235,7 @@ def argmax2(iterable, return_value=False):
if return_value:
return current_argmax, current_max
else:
return current_argmax
return current_argmax
def argmin(iterable):
......@@ -1034,10 +1250,15 @@ def argmax(iterable):
# {{{ cartesian products etc.
def cartesian_product(list1, list2):
for i in list1:
for j in list2:
yield (i, j)
def cartesian_product(*args):
if len(args) == 1:
for arg in args[0]:
yield (arg,)
return
first = args[:-1]
for prod in cartesian_product(*first):
for i in args[-1]:
yield (*prod, i)
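# Example (illustrative, not part of the original source): the generator walks
# the arguments like nested loops, last argument varying fastest.
#
#     list(cartesian_product("ab", [0, 1]))
#     # -> [('a', 0), ('a', 1), ('b', 0), ('b', 1)]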
def distinct_pairs(list1, list2):
......@@ -1071,7 +1292,7 @@ def average(iterable):
s = next(it)
count = 1
except StopIteration:
raise ValueError("empty average")
raise ValueError("empty average") from None
for value in it:
s = s + value
......@@ -1102,13 +1323,10 @@ class VarianceAggregator:
if self.entire_pop:
if self.n == 0:
return None
else:
return self.m2/self.n
else:
if self.n <= 1:
return None
else:
return self.m2/(self.n - 1)
return self.m2/self.n
if self.n <= 1:
return None
return self.m2/(self.n - 1)
def variance(iterable, entire_pop):
......@@ -1135,19 +1353,23 @@ def wandering_element(length, wanderer=1, landscape=0):
def indices_in_shape(shape):
from warnings import warn
warn("indices_in_shape is deprecated. You should prefer numpy.ndindex.",
DeprecationWarning, stacklevel=2)
if isinstance(shape, int):
shape = (shape,)
if not shape:
yield ()
elif len(shape) == 1:
for i in range(0, shape[0]):
for i in range(shape[0]):
yield (i,)
else:
remainder = shape[1:]
for i in range(0, shape[0]):
for i in range(shape[0]):
for rest in indices_in_shape(remainder):
yield (i,)+rest
yield (i, *rest)
def generate_nonnegative_integer_tuples_below(n, length=None, least=0):
......@@ -1182,7 +1404,7 @@ def generate_decreasing_nonnegative_tuples_summing_to(
yield ()
elif length == 1:
if n <= max_value:
#print "MX", n, max_value
# print "MX", n, max_value
yield (n,)
else:
return
......@@ -1191,10 +1413,10 @@ def generate_decreasing_nonnegative_tuples_summing_to(
max_value = n
for i in range(min_value, max_value+1):
#print "SIG", sig, i
# print "SIG", sig, i
for remainder in generate_decreasing_nonnegative_tuples_summing_to(
n-i, length-1, min_value, i):
yield (i,) + remainder
yield (i, *remainder)
def generate_nonnegative_integer_tuples_summing_to_at_most(n, length):
......@@ -1209,25 +1431,11 @@ def generate_nonnegative_integer_tuples_summing_to_at_most(n, length):
for i in range(n+1):
for remainder in generate_nonnegative_integer_tuples_summing_to_at_most(
n-i, length-1):
yield remainder + (i,)
def generate_all_nonnegative_integer_tuples(length, least=0):
assert length >= 0
current_max = least
while True:
for max_pos in range(length):
for prebase in generate_nonnegative_integer_tuples_below(
current_max, max_pos, least):
for postbase in generate_nonnegative_integer_tuples_below(
current_max+1, length-max_pos-1, least):
yield prebase + [current_max] + postbase
current_max += 1
yield (*remainder, i)
# backwards compatibility
generate_positive_integer_tuples_below = generate_nonnegative_integer_tuples_below
generate_all_positive_integer_tuples = generate_all_nonnegative_integer_tuples
def _pos_and_neg_adaptor(tuple_iter):
......@@ -1247,11 +1455,6 @@ def generate_all_integer_tuples_below(n, length, least_abs=0):
n, length, least_abs))
def generate_all_integer_tuples(length, least_abs=0):
return _pos_and_neg_adaptor(generate_all_nonnegative_integer_tuples(
length, least_abs))
def generate_permutations(original):
"""Generate all permutations of the list *original*.
......@@ -1262,7 +1465,7 @@ def generate_permutations(original):
else:
for perm_ in generate_permutations(original[1:]):
for i in range(len(perm_)+1):
#nb str[0:1] works in both string and list contexts
# nb str[0:1] works in both string and list contexts
yield perm_[:i] + original[0:1] + perm_[i:]
......@@ -1285,182 +1488,313 @@ def enumerate_basic_directions(dimensions):
# }}}
# {{{ index mangling
# {{{ graph algorithms
def get_read_from_map_from_permutation(original, permuted):
"""With a permutation given by C{original} and C{permuted},
generate a list C{rfm} of indices such that
C{permuted[i] == original[rfm[i]]}.
from pytools.graph import a_star as a_star_moved
Requires that the permutation can be inferred from
C{original} and C{permuted}.
.. doctest ::
a_star = MovedFunctionDeprecationWrapper(a_star_moved)
>>> for p1 in generate_permutations(range(5)):
... for p2 in generate_permutations(range(5)):
... rfm = get_read_from_map_from_permutation(p1, p2)
... p2a = [p1[rfm[i]] for i in range(len(p1))]
... assert p2 == p2a
"""
from warnings import warn
warn("get_read_from_map_from_permutation is deprecated and will be "
"removed in 2019", DeprecationWarning, stacklevel=2)
# }}}
assert len(original) == len(permuted)
where_in_original = dict(
(original[i], i) for i in range(len(original)))
assert len(where_in_original) == len(original)
return tuple(where_in_original[pi] for pi in permuted)
# {{{ formatting
def get_write_to_map_from_permutation(original, permuted):
"""With a permutation given by C{original} and C{permuted},
generate a list C{wtm} of indices such that
C{permuted[wtm[i]] == original[i]}.
# {{{ table formatting
class Table:
"""An ASCII table generator.
Requires that the permutation can be inferred from
C{original} and C{permuted}.
.. automethod:: __init__
.. automethod:: add_row
.. doctest::
.. autoproperty:: nrows
.. autoproperty:: ncolumns
>>> for p1 in generate_permutations(range(5)):
... for p2 in generate_permutations(range(5)):
... wtm = get_write_to_map_from_permutation(p1, p2)
... p2a = [0] * len(p2)
... for i, oi in enumerate(p1):
... p2a[wtm[i]] = oi
... assert p2 == p2a
.. automethod:: __str__
.. automethod:: github_markdown
.. automethod:: csv
.. automethod:: latex
.. automethod:: text_without_markup
"""
from warnings import warn
warn("get_write_to_map_from_permutation is deprecated and will be "
"removed in 2019", DeprecationWarning, stacklevel=2)
assert len(original) == len(permuted)
def __init__(self, alignments: tuple[str, ...] | None = None) -> None:
"""Create a new :class:`Table`.
where_in_permuted = dict(
(permuted[i], i) for i in range(len(permuted)))
:arg alignments: A :class:`tuple` of alignments of each column:
``"l"``, ``"c"``, or ``"r"``, for left, center, and right
alignment, respectively). Columns which have no alignment specifier
will use the last specified alignment. For example, with
``alignments=("l", "r")``, the third and all following
columns will use right alignment.
"""
assert len(where_in_permuted) == len(permuted)
return tuple(where_in_permuted[oi] for oi in original)
if alignments is None:
alignments = ("l",)
else:
if any(a not in ("l", "c", "r") for a in alignments):
raise ValueError(f"alignments are ('l', 'c', 'r'): {alignments}")
# }}}
alignments = tuple(alignments)
self.rows: list[tuple[str, ...]] = []
self.alignments = alignments
# {{{ graph algorithms
@property
def nrows(self) -> int:
"""The number of rows currently in the table."""
return len(self.rows)
def a_star( # pylint: disable=too-many-locals
initial_state, goal_state, neighbor_map,
estimate_remaining_cost=None,
get_step_cost=lambda x, y: 1
):
"""
With the default cost and heuristic, this amounts to Dijkstra's algorithm.
"""
@property
def ncolumns(self) -> int:
"""The number of columns currently in the table."""
return len(self.rows[0])
from heapq import heappop, heappush
def add_row(self, row: tuple[Any, ...]) -> None:
"""Add *row* to the table. Note that all rows must have the same number
of columns."""
if self.rows and len(row) != self.ncolumns:
raise ValueError(
f"tried to add a row with {len(row)} columns to "
f"a table with {self.ncolumns} columns")
self.rows.append(tuple(str(i) for i in row))
def _get_alignments(self) -> tuple[str, ...]:
# NOTE: If not all alignments were specified, extend alignments with the
# last alignment specified
return (
self.alignments
+ (self.alignments[-1],) * (self.ncolumns - len(self.alignments))
)[:self.ncolumns]
def _get_column_widths(self, rows) -> tuple[int, ...]:
return tuple(
max(len(row[i]) for row in rows) for i in range(self.ncolumns)
)
if estimate_remaining_cost is None:
def estimate_remaining_cost(x): # pylint: disable=function-redefined
if x != goal_state:
return 1
else:
return 0
def __str__(self) -> str:
"""
Returns a string representation of the table.
class AStarNode(object):
__slots__ = ["state", "parent", "path_cost"]
.. doctest ::
def __init__(self, state, parent, path_cost):
self.state = state
self.parent = parent
self.path_cost = path_cost
>>> tbl = Table(alignments=['l', 'r', 'l'])
>>> tbl.add_row([1, '|'])
>>> tbl.add_row([10, '20||'])
>>> print(tbl)
1 | |
---+------
10 | 20||
inf = float("inf")
init_remcost = estimate_remaining_cost(initial_state)
assert init_remcost != inf
"""
if not self.rows:
return ""
queue = [(init_remcost, AStarNode(initial_state, parent=None, path_cost=0))]
visited_states = set()
alignments = self._get_alignments()
col_widths = self._get_column_widths(self.rows)
while queue:
_, top = heappop(queue)
visited_states.add(top.state)
lines = [" | ".join([
cell.center(cwidth) if align == "c"
else cell.ljust(cwidth) if align == "l"
else cell.rjust(cwidth)
for cell, cwidth, align in zip(row, col_widths, alignments, strict=True)])
for row in self.rows]
lines[1:1] = ["+".join("-" * (cwidth + 1 + (i > 0))
for i, cwidth in enumerate(col_widths))]
if top.state == goal_state:
result = []
it = top
while it is not None:
result.append(it.state)
it = it.parent
return result[::-1]
return "\n".join(lines)
for state in neighbor_map[top.state]:
if state in visited_states:
continue
def github_markdown(self) -> str:
r"""Returns a string representation of the table formatted as
`GitHub-Flavored Markdown.
<https://docs.github.com/en/github/writing-on-github/organizing-information-with-tables>`__
remaining_cost = estimate_remaining_cost(state)
if remaining_cost == inf:
continue
step_cost = get_step_cost(top, state)
.. doctest ::
estimated_path_cost = top.path_cost+step_cost+remaining_cost
heappush(queue,
(estimated_path_cost,
AStarNode(state, top, path_cost=top.path_cost + step_cost)))
>>> tbl = Table(alignments=['l', 'r', 'l'])
>>> tbl.add_row([1, '|'])
>>> tbl.add_row([10, '20||'])
>>> print(tbl.github_markdown())
1 | \|
:--|-------:
10 | 20\|\|
raise RuntimeError("no solution")
"""
if not self.rows:
return ""
def escape(cell: str) -> str:
# Pipe symbols ('|') must be replaced
return cell.replace("|", "\\|")
rows = [tuple(escape(cell) for cell in row) for row in self.rows]
alignments = self._get_alignments()
col_widths = self._get_column_widths(rows)
lines = [" | ".join([
cell.center(cwidth) if align == "c"
else cell.ljust(cwidth) if align == "l"
else cell.rjust(cwidth)
for cell, cwidth, align in zip(row, col_widths, alignments, strict=True)])
for row in rows]
lines[1:1] = ["|".join(
(":" + "-" * (cwidth - 1 + (i > 0)) + ":") if align == "c"
else (":" + "-" * (cwidth + (i > 0))) if align == "l"
else ("-" * (cwidth + (i > 0)) + ":")
for i, (cwidth, align) in enumerate(
zip(col_widths, alignments, strict=True)))]
# }}}
return "\n".join(lines)
def csv(self,
dialect: str = "excel",
csv_kwargs: dict[str, Any] | None = None) -> str:
"""Returns a string containing a CSV representation of the table.
# {{{ formatting
:arg dialect: String passed to :func:`csv.writer`.
:arg csv_kwargs: Dict of arguments passed to :func:`csv.writer`.
# {{{ table formatting
.. doctest ::
class Table:
"""An ASCII table generator.
>>> tbl = Table()
>>> tbl.add_row([1, ","])
>>> tbl.add_row([10, 20])
>>> print(tbl.csv())
1,","
10,20
"""
.. automethod:: add_row
.. automethod:: __str__
.. automethod:: latex
"""
if not self.rows:
return ""
def __init__(self):
self.rows = []
import csv
import io
def add_row(self, row):
self.rows.append([str(i) for i in row])
if csv_kwargs is None:
csv_kwargs = {}
def __str__(self):
columns = len(self.rows[0])
col_widths = [max(len(row[i]) for row in self.rows)
for i in range(columns)]
# Default is "\r\n"
if "lineterminator" not in csv_kwargs:
csv_kwargs["lineterminator"] = "\n"
lines = [
"|".join([cell.ljust(col_width)
for cell, col_width in zip(row, col_widths)])
for row in self.rows]
lines[1:1] = ["+".join("-"*col_width
for col_width in col_widths)]
return "\n".join(lines)
output = io.StringIO()
writer = csv.writer(output, dialect, **csv_kwargs)
writer.writerows(self.rows)
return output.getvalue().rstrip(csv_kwargs["lineterminator"])
def latex(self,
skip_lines: int = 0,
hline_after: tuple[int, ...] | None = None) -> str:
r"""Returns a string containing the rows of a LaTeX representation of
the table.
:arg skip_lines: number of lines to skip at the start of the table.
:arg hline_after: list of row indices after which to add an ``hline``
(the indices must subtract *skip_lines*, if non-zero).
.. doctest::
>>> tbl = Table()
>>> tbl.add_row([0, "skipped"])
>>> tbl.add_row([1, "apple"])
>>> tbl.add_row([2, "pear"])
>>> print(tbl.latex(skip_lines=1))
1 & apple \\
2 & pear \\
"""
if not self.rows:
return ""
def latex(self, skip_lines=0, hline_after=None):
if hline_after is None:
hline_after = []
hline_after = ()
lines = []
for row_nr, row in list(enumerate(self.rows))[skip_lines:]:
lines.append(" & ".join(row)+r" \\")
for row_nr, row in enumerate(self.rows[skip_lines:]):
lines.append(fr"{' & '.join(row)} \\")
if row_nr in hline_after:
lines.append(r"\hline")
return "\n".join(lines)
def text_without_markup(self) -> str:
"""Returns a string representation of the table without markup.
.. doctest::
>>> tbl = Table()
>>> tbl.add_row([0, "orange"])
>>> tbl.add_row([1111, "apple"])
>>> tbl.add_row([2, "pear"])
>>> print(tbl.text_without_markup())
0 orange
1111 apple
2 pear
"""
if not self.rows:
return ""
alignments = self._get_alignments()
col_widths = self._get_column_widths(self.rows)
lines = [" ".join([
cell.center(cwidth) if align == "c"
else cell.ljust(cwidth) if align == "l"
else cell.rjust(cwidth)
for cell, cwidth, align in zip(row, col_widths, alignments, strict=True)])
for row in self.rows]
# Remove the extra space added by the last cell
lines = [line.rstrip() for line in lines]
return "\n".join(lines)
def merge_tables(*tables: Table,
skip_columns: tuple[int, ...] | None = None) -> Table:
"""
:arg skip_columns: a :class:`tuple` of column indices to skip in all the
tables except the first one.
"""
if len(tables) == 1:
return tables[0]
if any(tables[0].nrows != tbl.nrows for tbl in tables[1:]):
raise ValueError("tables do not have the same number of rows")
if isinstance(skip_columns, int):
skip_columns = (skip_columns,)
def remove_columns(i, row):
if i == 0 or skip_columns is None:
return row
return tuple(
entry for i, entry in enumerate(row) if i not in skip_columns
)
alignments = sum((
remove_columns(i, tbl._get_alignments())
for i, tbl in enumerate(tables)
), ())
result = Table(alignments=alignments)
for i in range(tables[0].nrows):
row = []
for j, tbl in enumerate(tables):
row.extend(remove_columns(j, tbl.rows[i]))
result.add_row(tuple(row))
return result
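# Usage sketch (not part of the original source; 'timings_a' and 'timings_b'
# are hypothetical Table instances with matching row counts): glue two tables
# side by side, dropping the duplicated first column of every table after the
# first one.
#
#     merged = merge_tables(timings_a, timings_b, skip_columns=(0,))
#     print(merged.github_markdown())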
# }}}
# {{{ histogram formatting
def string_histogram( # pylint: disable=too-many-arguments,too-many-locals
def string_histogram(
iterable, min_value=None, max_value=None,
bin_count=20, width=70, bin_starts=None, use_unicode=True):
if bin_starts is None:
......@@ -1476,9 +1810,9 @@ def string_histogram( # pylint: disable=too-many-arguments,too-many-locals
from bisect import bisect
for value in iterable:
if max_value is not None and value > max_value or value < bin_starts[0]:
if (max_value is not None and value > max_value) or value < bin_starts[0]:
from warnings import warn
warn("string_histogram: out-of-bounds value ignored")
warn("string_histogram: out-of-bounds value ignored", stacklevel=2)
else:
bin_nr = bisect(bin_starts, value)-1
try:
......@@ -1487,28 +1821,27 @@ def string_histogram( # pylint: disable=too-many-arguments,too-many-locals
print(value, bin_nr, bin_starts)
raise
from math import floor, ceil
from math import ceil, floor
if use_unicode:
def format_bar(cnt):
scaled = cnt*width/max_count
full = int(floor(scaled))
eighths = int(ceil((scaled-full)*8))
full = floor(scaled)
eighths = ceil((scaled-full)*8)
if eighths:
return full*six.unichr(0x2588) + six.unichr(0x2588+(8-eighths))
else:
return full*six.unichr(0x2588)
return full*chr(0x2588) + chr(0x2588+(8-eighths))
return full*chr(0x2588)
else:
def format_bar(cnt):
return int(ceil(cnt*width/max_count))*"#"
return ceil(cnt*width/max_count)*"#"
max_count = max(bins)
total_count = sum(bins)
return "\n".join("%9g |%9d | %3.0f %% | %s" % (
return "\n".join("{:9g} |{:9d} | {:3.0f} % | {}".format(
bin_start,
bin_value,
bin_value/total_count*100,
format_bar(bin_value))
for bin_start, bin_value in zip(bin_starts, bins))
for bin_start, bin_value in zip(bin_starts, bins, strict=True))
# }}}
......@@ -1521,29 +1854,30 @@ def word_wrap(text, width, wrap_using="\n"):
breaks are posix newlines (``\n``).
"""
space_or_break = [" ", wrap_using]
return reduce(lambda line, word: '%s%s%s' %
(line,
space_or_break[(len(line)-line.rfind('\n')-1
+ len(word.split('\n', 1)[0])
>= width)],
word),
text.split(' ')
)
return reduce(lambda line, word: "{}{}{}".format(
line,
space_or_break[(
len(line) - line.rfind("\n") - 1
+ len(word.split("\n", 1)[0])
) >= width],
word),
text.split(" ")
)
# }}}
# {{{ command line interfaces -------------------------------------------------
# {{{ command line interfaces
def _exec_arg(arg, execenv):
import os
if os.access(arg, os.F_OK):
exec(compile(open(arg, "r"), arg, 'exec'), execenv)
exec(compile(open(arg), arg, "exec"), execenv)
else:
exec(compile(arg, "<command line>", 'exec'), execenv)
exec(compile(arg, "<command line>", "exec"), execenv)
class CPyUserInterface(object):
class CPyUserInterface:
class Parameters(Record):
pass
......@@ -1557,7 +1891,7 @@ class CPyUserInterface(object):
self.doc = doc
def show_usage(self, progname):
print("usage: %s <FILE-OR-STATEMENTS>" % progname)
print(f"usage: {progname} <FILE-OR-STATEMENTS>")
print()
print("FILE-OR-STATEMENTS may either be Python statements of the form")
print("'variable1 = value1; variable2 = value2' or the name of a file")
......@@ -1567,16 +1901,16 @@ class CPyUserInterface(object):
print()
print("The following variables are recognized:")
for v in sorted(self.variables):
print(" %s = %s" % (v, self.variables[v]))
print(f" {v} = {self.variables[v]}")
if v in self.doc:
print(" %s" % self.doc[v])
print(f" {self.doc[v]}")
print()
print("The following constants are supplied:")
for c in sorted(self.constants):
print(" %s = %s" % (c, self.constants[c]))
print(f" {c} = {self.constants[c]}")
if c in self.doc:
print(" %s" % self.doc[c])
print(f" {self.doc[c]}")
def gather(self, argv=None):
if argv is None:
......@@ -1603,11 +1937,10 @@ class CPyUserInterface(object):
- set(self.constants.keys())):
if not (added_key.startswith("user_") or added_key.startswith("_")):
raise ValueError(
"invalid setup key: '%s' "
"(user variables must start with 'user_' or '_')"
% added_key)
f"invalid setup key: '{added_key}' "
"(user variables must start with 'user_' or '_')")
result = self.Parameters(dict((key, execenv[key]) for key in self.variables))
result = self.Parameters({key: execenv[key] for key in self.variables})
self.validate(result)
return result
......@@ -1617,28 +1950,10 @@ class CPyUserInterface(object):
# }}}
# {{{ debugging
class StderrToStdout(object):
class StderrToStdout:
def __enter__(self):
# pylint: disable=attribute-defined-outside-init
self.stderr_backup = sys.stderr
sys.stderr = sys.stdout
......@@ -1647,7 +1962,24 @@ class StderrToStdout(object):
del self.stderr_backup
def typedump(val, max_seq=5, special_handlers=None):
def typedump(val: Any, max_seq: int = 5,
special_handlers: Mapping[type, Callable] | None = None,
fully_qualified_name: bool = True) -> str:
"""
Return a string representation of the type of *val*, recursing into
iterable objects.
:arg val: The object for which the type should be returned.
:arg max_seq: For iterable objects, the maximum number of elements to
include in the return string. Lower this value if you get a
:class:`RecursionError`.
:arg special_handlers: An optional mapping of specific types to special
handlers.
:arg fully_qualified_name: Return fully qualified names, that is,
include module names and use ``__qualname__`` instead of ``__name__``.
:returns: A string representation of the type of *val*.
"""
if special_handlers is None:
special_handlers = {}
......@@ -1658,30 +1990,42 @@ def typedump(val, max_seq=5, special_handlers=None):
else:
return hdlr(val)
def objname(obj: Any) -> str:
if type(obj).__module__ == "builtins":
if fully_qualified_name:
return type(obj).__qualname__
return type(obj).__name__
if fully_qualified_name:
return type(obj).__module__ + "." + type(obj).__qualname__
return type(obj).__name__
# Special handling for 'str' since it is also iterable
if isinstance(val, str):
return "str"
try:
len(val)
except TypeError:
return type(val).__name__
return objname(val)
else:
if isinstance(val, dict):
return "{%s}" % (
", ".join(
"%r: %s" % (str(k), typedump(v))
for k, v in six.iteritems(val)))
f"{str(k)!r}: {typedump(v)}"
for k, v in val.items()))
try:
if len(val) > max_seq:
return "%s(%s,...)" % (
type(val).__name__,
",".join(typedump(x, max_seq, special_handlers)
for x in val[:max_seq]))
else:
return "%s(%s)" % (
type(val).__name__,
",".join(typedump(x, max_seq, special_handlers)
for x in val))
t = ",".join(typedump(x, max_seq, special_handlers)
for x in val[:max_seq])
return f"{objname(val)}({t},...)"
t = ",".join(typedump(x, max_seq, special_handlers)
for x in val)
return f"{objname(val)}({t})"
except TypeError:
return val.__class__.__name__
return objname(val)
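# Example (illustrative, not part of the original source): typedump recurses
# into containers and abbreviates anything longer than max_seq.
#
#     typedump([1, (2.0, "x")])    # -> 'list(int,tuple(float,str))'
#     typedump(list(range(10)))    # -> 'list(int,int,int,int,int,...)'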
def invoke_editor(s, filename="edit.txt", descr="the file"):
......@@ -1703,10 +2047,9 @@ def invoke_editor(s, filename="edit.txt", descr="the file"):
else:
print("(Set the EDITOR environment variable to be "
"dropped directly into an editor next time.)")
input("Edit %s at %s now, then hit [Enter]:"
% (descr, full_name))
input(f"Edit {descr} at {full_name} now, then hit [Enter]:")
inf = open(full_name, "r")
inf = open(full_name)
result = inf.read()
inf.close()
......@@ -1717,7 +2060,7 @@ def invoke_editor(s, filename="edit.txt", descr="the file"):
# {{{ progress bars
class ProgressBar(object): # pylint: disable=too-many-instance-attributes
class ProgressBar:
"""
.. automethod:: draw
.. automethod:: progress
......@@ -1760,12 +2103,13 @@ class ProgressBar(object): # pylint: disable=too-many-instance-attributes
self.speed_meas_start_done = self.done
if self.time_per_step is not None:
eta_str = "%7.1fs " % max(
0, (self.total-self.done) * self.time_per_step)
eta_str = "{:7.1f}s ".format(
max(0, (self.total-self.done) * self.time_per_step)
)
else:
eta_str = "?"
sys.stderr.write("%-20s [%s] ETA %s\r" % (
sys.stderr.write("{:<20} [{}] ETA {}\r".format(
self.description,
squares*"#"+(self.length-squares)*" ",
eta_str))
......@@ -1798,11 +2142,11 @@ class ProgressBar(object): # pylint: disable=too-many-instance-attributes
def assert_not_a_file(name):
import os
if os.access(name, os.F_OK):
raise IOError("file `%s' already exists" % name)
raise OSError(f"file `{name}' already exists")
def add_python_path_relative_to_script(rel_path):
from os.path import dirname, join, abspath
from os.path import abspath, dirname, join
script_name = sys.argv[0]
rel_script_dir = dirname(script_name)
......@@ -1818,26 +2162,15 @@ def common_dtype(dtypes, default=None):
dtypes = list(dtypes)
if dtypes:
return argmax2((dtype, dtype.num) for dtype in dtypes)
else:
if default is not None:
return default
else:
raise ValueError(
"cannot find common dtype of empty dtype list")
if default is not None:
return default
raise ValueError(
"cannot find common dtype of empty dtype list")
def to_uncomplex_dtype(dtype):
import numpy
if dtype == numpy.complex64:
return numpy.float32
elif dtype == numpy.complex128:
return numpy.float64
if dtype == numpy.float32:
return numpy.float32
elif dtype == numpy.float64:
return numpy.float64
else:
raise TypeError("unrecgonized dtype '%s'" % dtype)
import numpy as np
return np.array(1, dtype=dtype).real.dtype.type
def match_precision(dtype, dtype_to_match):
......@@ -1850,13 +2183,10 @@ def match_precision(dtype, dtype_to_match):
if dtype_is_complex:
if tgt_is_double:
return numpy.dtype(numpy.complex128)
else:
return numpy.dtype(numpy.complex64)
else:
if tgt_is_double:
return numpy.dtype(numpy.float64)
else:
return numpy.dtype(numpy.float32)
return numpy.dtype(numpy.complex64)
if tgt_is_double:
return numpy.dtype(numpy.float64)
return numpy.dtype(numpy.float32)
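# (Editor's addition, illustrative only -- not part of the original source.)
# Rough expectations for the dtype helpers above; the helper name is
# hypothetical and only used for illustration.
def _example_dtype_helpers() -> None:
    import numpy as np
    # to_uncomplex_dtype drops the imaginary part of the dtype.
    assert to_uncomplex_dtype(np.complex128) == np.float64
    # match_precision keeps the complex-ness of the first dtype but adopts the
    # precision of the second one.
    assert (match_precision(np.dtype(np.complex64), np.dtype(np.float64))
            == np.dtype(np.complex128))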
# }}}
......@@ -1868,17 +2198,21 @@ def generate_unique_names(prefix):
try_num = 0
while True:
yield "%s_%d" % (prefix, try_num)
yield f"{prefix}_{try_num}"
try_num += 1
def generate_numbered_unique_names(prefix, num=None):
UNIQUE_NAME_GEN_COUNTER_RE = re.compile(r"^(?P<based_on>\w+)_(?P<counter>\d+)$")
def generate_numbered_unique_names(
prefix: str, num: int | None = None) -> Iterable[tuple[int, str]]:
if num is None:
yield (0, prefix)
num = 0
while True:
name = "%s_%d" % (prefix, num)
name = f"{prefix}_{num}"
num += 1
yield (num, name)
......@@ -1887,25 +2221,39 @@ generate_unique_possibilities = MovedFunctionDeprecationWrapper(
generate_unique_names)
class UniqueNameGenerator(object):
class UniqueNameGenerator:
"""
Class that creates a new :class:`str` on each :meth:`__call__` that is
unique to the generator.
.. automethod:: __init__
.. automethod:: is_name_conflicting
.. automethod:: add_name
.. automethod:: add_names
.. automethod:: __call__
"""
def __init__(self, existing_names=None, forced_prefix=""):
def __init__(self,
existing_names: Collection[str] | None = None,
forced_prefix: str = ""):
"""
Create a new :class:`UniqueNameGenerator`.
:arg existing_names: a :class:`set` of existing names that will be
skipped when generating new names.
:arg forced_prefix: all generated :class:`str` have this prefix.
"""
if existing_names is None:
existing_names = set()
self.existing_names = existing_names.copy()
self.existing_names = set(existing_names)
self.forced_prefix = forced_prefix
self.prefix_to_counter = {}
self.prefix_to_counter: dict[str, int] = {}
def is_name_conflicting(self, name):
def is_name_conflicting(self, name: str) -> bool:
"""Returns *True* if *name* conflicts with an existing :class:`str`."""
return name in self.existing_names
def _name_added(self, name):
def _name_added(self, name: str) -> None:
"""Callback to alert subclasses when a name has been added.
.. note::
......@@ -1913,27 +2261,52 @@ class UniqueNameGenerator(object):
This will not get called for the names in the *existing_names*
argument to :meth:`__init__`.
"""
pass
def add_name(self, name):
if self.is_name_conflicting(name):
raise ValueError("name '%s' conflicts with existing names")
def add_name(self, name: str, *, conflicting_ok: bool = False) -> None:
"""
:arg conflicting_ok: A flag to dictate the behavior when *name* is
conflicting with the set of existing names. If *True*, a conflict
is silently passed. If *False*, a :class:`ValueError` is raised on
encountering a conflict.
"""
if (not conflicting_ok) and self.is_name_conflicting(name):
raise ValueError(f"name '{name}' conflicts with existing names")
if not name.startswith(self.forced_prefix):
raise ValueError("name '%s' does not start with required prefix")
raise ValueError(
f"name '{name}' does not start with required prefix "
f"'{self.forced_prefix}'")
self.existing_names.add(name)
self._name_added(name)
def add_names(self, names):
def add_names(self, names: Iterable[str],
*,
conflicting_ok: bool = False) -> None:
"""
:arg conflicting_ok: Plainly passed to :meth:`UniqueNameGenerator.add_name`.
"""
for name in names:
self.add_name(name)
self.add_name(name, conflicting_ok=conflicting_ok)
def __call__(self, based_on="id"):
def __call__(self, based_on: str = "id") -> str:
"""Returns a new unique name."""
based_on = self.forced_prefix + based_on
counter = self.prefix_to_counter.get(based_on, None)
for counter, var_name in generate_numbered_unique_names(based_on, counter):
# {{{ try to get counter from based_on if not already present
if counter is None:
counter_match = UNIQUE_NAME_GEN_COUNTER_RE.match(based_on)
if counter_match:
based_on = counter_match.groupdict()["based_on"]
counter = int(counter_match.groupdict()["counter"])
# }}}
for counter, var_name in generate_numbered_unique_names(based_on, counter): # noqa: B020,B007
if not self.is_name_conflicting(var_name):
break
......@@ -1950,13 +2323,11 @@ class UniqueNameGenerator(object):
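# (Editor's addition, illustrative only -- not part of the original source.)
# Expected behavior of UniqueNameGenerator per the logic above; the helper
# name is hypothetical.
def _example_unique_name_generator() -> None:
    ng = UniqueNameGenerator()
    assert ng("tmp") == "tmp"       # first request gets the plain prefix
    assert ng("tmp") == "tmp_0"     # later requests get numbered suffixes
    ng.add_name("tmp_1")            # reserve a name manually
    assert ng("tmp") == "tmp_2"     # the reserved name is skipped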
# {{{ recursion limit
class MinRecursionLimit(object):
class MinRecursionLimit:
def __init__(self, min_rec_limit):
self.min_rec_limit = min_rec_limit
def __enter__(self):
# pylint: disable=attribute-defined-outside-init
self.prev_recursion_limit = sys.getrecursionlimit()
new_limit = max(self.prev_recursion_limit, self.min_rec_limit)
sys.setrecursionlimit(new_limit)
......@@ -1989,8 +2360,14 @@ def download_from_web_if_not_present(url, local_name=None):
local_name = basename(url)
if not exists(local_name):
from six.moves.urllib.request import urlopen
with urlopen(url) as inf:
from urllib.request import Request, urlopen
from pytools.version import VERSION_TEXT
req = Request(url, headers={
"User-Agent": f"pytools/{VERSION_TEXT}"
})
with urlopen(req) as inf:
contents = inf.read()
with open(local_name, "wb") as outf:
......@@ -2001,11 +2378,11 @@ def download_from_web_if_not_present(url, local_name=None):
# {{{ find git revisions
def find_git_revision(tree_root): # pylint: disable=too-many-locals
def find_git_revision(tree_root):
# Keep this routine self-contained so that it can be copy-pasted into
# setup.py.
from os.path import join, exists, abspath
from os.path import abspath, exists, join
tree_root = abspath(tree_root)
if not exists(join(tree_root, ".git")):
......@@ -2016,23 +2393,22 @@ def find_git_revision(tree_root): # pylint: disable=too-many-locals
# https://github.com/numpy/numpy/blob/055ce3e90b50b5f9ef8cf1b8641c42e391f10735/setup.py#L70-L92
import os
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
for k in ["SYSTEMROOT", "PATH", "HOME"]:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
env["LANGUAGE"] = "C"
env["LANG"] = "C"
env["LC_ALL"] = "C"
from subprocess import Popen, PIPE, STDOUT
from subprocess import PIPE, STDOUT, Popen
p = Popen(["git", "rev-parse", "HEAD"], shell=False,
stdin=PIPE, stdout=PIPE, stderr=STDOUT, close_fds=True,
cwd=tree_root, env=env)
(git_rev, _) = p.communicate()
if sys.version_info >= (3,):
git_rev = git_rev.decode()
git_rev = git_rev.decode()
git_rev = git_rev.rstrip()
......@@ -2040,7 +2416,7 @@ def find_git_revision(tree_root): # pylint: disable=too-many-locals
assert retcode is not None
if retcode != 0:
from warnings import warn
warn("unable to find git revision")
warn("unable to find git revision", stacklevel=1)
return None
return git_rev
......@@ -2048,7 +2424,7 @@ def find_git_revision(tree_root): # pylint: disable=too-many-locals
def find_module_git_revision(module_file, n_levels_up):
from os.path import dirname, join
tree_root = join(*([dirname(module_file)] + [".." * n_levels_up]))
tree_root = join(*([dirname(module_file), *([".."] * n_levels_up)]))
return find_git_revision(tree_root)
......@@ -2077,7 +2453,10 @@ def reshaped_view(a, newshape):
# {{{ process timer
class ProcessTimer(object):
SUPPORTS_PROCESS_TIME = True
class ProcessTimer:
"""Measures elapsed wall time and process time.
.. automethod:: __enter__
......@@ -2089,20 +2468,16 @@ class ProcessTimer(object):
.. attribute:: wall_elapsed
.. attribute:: process_elapsed
Only available in Python 3.3+.
.. versionadded:: 2018.5
"""
def __init__(self):
import time
if sys.version_info >= (3, 3):
self.perf_counter_start = time.perf_counter()
self.process_time_start = time.process_time()
self.perf_counter_start = time.perf_counter()
self.process_time_start = time.process_time()
else:
import timeit
self.time_start = timeit.default_timer()
self.wall_elapsed = None
self.process_elapsed = None
def __enter__(self):
return self
......@@ -2111,24 +2486,39 @@ class ProcessTimer(object):
self.done()
def done(self):
# pylint: disable=attribute-defined-outside-init
import time
if sys.version_info >= (3, 3):
self.wall_elapsed = time.perf_counter() - self.perf_counter_start
self.process_elapsed = time.process_time() - self.process_time_start
self.wall_elapsed = time.perf_counter() - self.perf_counter_start
self.process_elapsed = time.process_time() - self.process_time_start
else:
import timeit
self.wall_elapsed = timeit.default_timer() - self.time_start
self.process_elapsed = None
def __str__(self):
cpu = self.process_elapsed / self.wall_elapsed
return f"{self.wall_elapsed:.2f}s wall {cpu:.2f}x CPU"
def __repr__(self):
wall = self.wall_elapsed
process = self.process_elapsed
return (f"{type(self).__name__}"
f"(wall_elapsed={wall!r}s, process_elapsed={process!r}s)")
# }}}
# {{{ log utilities
class ProcessLogger(object): # pylint: disable=too-many-instance-attributes
def _log_start_if_long(logger, sleep_duration, done_indicator,
noisy_level, description):
from time import sleep
sleep(sleep_duration)
if not done_indicator[0]:
logger.log(
noisy_level, "%s: started %.gs ago",
description,
sleep_duration)
class ProcessLogger:
"""Logs the completion time of a (presumably) lengthy process to :mod:`logging`.
Only uses a high log level if the process took perceptible time.
......@@ -2140,7 +2530,7 @@ class ProcessLogger(object): # pylint: disable=too-many-instance-attributes
default_noisy_level = logging.INFO
def __init__( # pylint: disable=too-many-arguments
def __init__(
self, logger, description,
silent_level=None, noisy_level=None, long_threshold_seconds=None):
self.logger = logger
......@@ -2152,50 +2542,67 @@ class ProcessLogger(object): # pylint: disable=too-many-instance-attributes
0.3 if long_threshold_seconds is None else long_threshold_seconds)
self.logger.log(self.silent_level, "%s: start", self.description)
self.is_done = False
self._done_indicator = [False]
import threading
self.late_start_log_thread = threading.Thread(target=self._log_start_if_long)
# Do not delay interpreter exit if thread not finished.
self.late_start_log_thread.daemon = True
self.late_start_log_thread.start()
self.late_start_log_thread = threading.Thread(
target=_log_start_if_long,
args=(logger, 10*self.long_threshold_seconds, self._done_indicator,
self.noisy_level, self.description),
self.timer = ProcessTimer()
# Do not delay interpreter exit if thread not finished.
daemon=True)
def _log_start_if_long(self):
from time import sleep
# https://github.com/firedrakeproject/firedrake/issues/1422
# Starting a thread may irrecoverably break various environments,
# e.g. MPI.
#
# Since the late-start logging is an optional 'quality-of-life'
# feature for interactive use, do not do it unless there is (weak)
# evidence of interactive use.
import sys
if sys.stdin is None:
# Can happen, e.g., if pudb is controlling the console.
use_late_start_logging = False
elif hasattr(sys.stdin, "closed") and not sys.stdin.closed:
# can query stdin.isatty() only if stdin's open
use_late_start_logging = sys.stdin.isatty()
else:
use_late_start_logging = False
sleep_duration = 10*self.long_threshold_seconds
sleep(sleep_duration)
import os
if os.environ.get("PYTOOLS_LOG_NO_THREADS", ""):
use_late_start_logging = False
if not self.is_done:
self.logger.log(
self.noisy_level, "%s: started %.gs ago",
self.description,
sleep_duration)
if use_late_start_logging:
try:
self.late_start_log_thread.start()
except RuntimeError:
# https://github.com/firedrakeproject/firedrake/issues/1422
#
# Starting a thread may fail in various environments, e.g. MPI.
# Since the late-start logging is an optional 'quality-of-life'
# feature for interactive use, tolerate failures of it without
# warning.
pass
def done( # pylint: disable=keyword-arg-before-vararg
self.timer = ProcessTimer()
def done(
self, extra_msg=None, *extra_fmt_args):
self.timer.done()
self.is_done = True
wall_elapsed = self.timer.wall_elapsed
process_elapsed = self.timer.process_elapsed
self._done_indicator[0] = True
completion_level = (
self.noisy_level
if wall_elapsed > self.long_threshold_seconds
if self.timer.wall_elapsed > self.long_threshold_seconds
else self.silent_level)
if process_elapsed is not None:
msg = "%s: completed (%.2fs wall, %.1fx CPU)"
fmt_args = [self.description, wall_elapsed, process_elapsed/wall_elapsed]
else:
msg = "%s: completed (%f.2s wall)"
fmt_args = [self.description, wall_elapsed]
msg = "%s: completed (%s)"
fmt_args = [self.description, str(self.timer)]
if extra_msg:
msg += ": " + extra_msg
msg = f"{msg}: {extra_msg}"
fmt_args.extend(extra_fmt_args)
self.logger.log(completion_level, msg, *fmt_args)
......@@ -2211,20 +2618,25 @@ class DebugProcessLogger(ProcessLogger):
default_noisy_level = logging.DEBUG
class log_process(object): # noqa: N801
class log_process: # noqa: N801
"""A decorator that uses :class:`ProcessLogger` to log data about calls
to the wrapped function.
.. automethod:: __init__
.. automethod:: __call__
"""
def __init__(self, logger, description=None):
def __init__(self, logger, description=None, long_threshold_seconds=None):
self.logger = logger
self.description = description
self.long_threshold_seconds = long_threshold_seconds
def __call__(self, wrapped):
def wrapper(*args, **kwargs):
with ProcessLogger(
self.logger,
self.description or wrapped.__name__):
self.description or wrapped.__name__,
long_threshold_seconds=self.long_threshold_seconds):
return wrapped(*args, **kwargs)
from functools import update_wrapper
......@@ -2235,6 +2647,413 @@ class log_process(object): # noqa: N801
# }}}
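# (Editor's addition, illustrative only -- not part of the original source.)
# The two usual entry points of the log utilities above; logger name and the
# timed workloads are hypothetical.
def _example_process_logging() -> None:
    logger = logging.getLogger("pytools.example")

    # Explicit use: logs "...: start" immediately and a completion message
    # (with wall/CPU time) once the block finishes.
    with ProcessLogger(logger, "assemble matrix"):
        sum(i * i for i in range(10**6))

    # Decorator use: every call to solve() is wrapped in a ProcessLogger.
    @log_process(logger, long_threshold_seconds=1.0)
    def solve():
        sum(i * i for i in range(10**6))

    solve()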
# {{{ sorting in natural order
def natorder(item):
"""Return a key for natural order string comparison.
See :func:`natsorted`.
.. versionadded:: 2020.1
"""
import re
result = []
for (int_val, string_val) in re.findall(r"(\d+)|(\D+)", item):
if int_val:
result.append(int(int_val))
# Tie-breaker in case of leading zeros in *int_val*. Longer values
# compare smaller to preserve order of numbers in decimal notation,
# e.g., "1.001" < "1.01"
# (cf. https://github.com/sourcefrog/natsort)
result.append(-len(int_val))
else:
result.append(string_val)
return result
def natsorted(iterable, key=None, reverse=False):
"""Sort using natural order [1]_, as opposed to lexicographic order.
Example::
>>> sorted(["_10", "_1", "_9"]) == ["_1", "_10", "_9"]
True
>>> natsorted(["_10", "_1", "_9"]) == ["_1", "_9", "_10"]
True
:arg iterable: an iterable to be sorted. It must contain only strings, unless
*key* is specified.
:arg key: if provided, a key function that returns strings for ordering
using natural order.
:arg reverse: if *True*, sorts in descending order.
:returns: a sorted list
.. [1] https://en.wikipedia.org/wiki/Natural_sort_order
.. versionadded:: 2020.1
"""
if key is None:
def key(x):
return x
return sorted(iterable, key=lambda y: natorder(key(y)), reverse=reverse)
# }}}
# {{{ resolve_name
# https://github.com/python/cpython/commit/1ed61617a4a6632905ad6a0b440cd2cafb8b6414
_DOTTED_WORDS = r"[a-z_]\w*(\.[a-z_]\w*)*"
_NAME_PATTERN = re.compile(f"^({_DOTTED_WORDS})(:({_DOTTED_WORDS})?)?$", re.I)
del _DOTTED_WORDS
def resolve_name(name):
"""A backport of :func:`pkgutil.resolve_name` (added in Python 3.9).
.. versionadded:: 2021.1.2
"""
from warnings import warn
warn("'pytools.resolve_name' is deprecated and will be removed in 2024. "
"Use 'pkgutil.resolve_name' from the standard library instead.",
DeprecationWarning, stacklevel=2)
import pkgutil
return pkgutil.resolve_name(name)
# }}}
# {{{ unordered_hash
def unordered_hash(hash_instance: Any,
iterable: Iterable[Any],
hash_constructor: Callable[[], Any] | None = None) -> Any:
"""Using a hash algorithm given by the parameter-less constructor
*hash_constructor*, return a hash object whose internal state
depends on the entries of *iterable*, but not their order. If *hash*
is the instance returned by evaluating ``hash_constructor()``, then
each entry *i* of the iterable must permit ``hash.update(i)`` to
succeed. An example of *hash_constructor* is ``hashlib.sha256``
from :mod:`hashlib`. ``hash.digest_size`` must also be defined.
If *hash_constructor* is not provided, ``hash_instance.name`` is
used to deduce it.
:returns: the updated *hash_instance*.
.. warning::
The construction used in this function is likely not cryptographically
secure. Do not use this function in a security-relevant context.
.. versionadded:: 2021.2
"""
if hash_constructor is None:
import hashlib
from functools import partial
hash_constructor = partial(hashlib.new, hash_instance.name)
assert hash_constructor is not None
h_int = 0
for i in iterable:
h_i = hash_constructor()
h_i.update(i)
# Using sys.byteorder (for efficiency) here technically makes the
# hash system-dependent (which it should not be), however the
# effect of this is undone by the to_bytes conversion below, while
# left invariant by the intervening XOR operations (which do not
# mix adjacent bits).
h_int = h_int ^ int.from_bytes(h_i.digest(), sys.byteorder)
hash_instance.update(h_int.to_bytes(hash_instance.digest_size, sys.byteorder))
return hash_instance
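# (Editor's addition, illustrative only -- not part of the original source.)
# Because the per-entry digests are combined with XOR, the result is
# independent of iteration order; the helper name is hypothetical.
def _example_unordered_hash() -> None:
    import hashlib
    h1 = unordered_hash(hashlib.sha256(), [b"a", b"b", b"c"])
    h2 = unordered_hash(hashlib.sha256(), [b"c", b"a", b"b"])
    assert h1.hexdigest() == h2.hexdigest()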
# }}}
# {{{ sphere_sample
def sphere_sample_equidistant(npoints_approx: int, r: float = 1.0):
"""Generate points regularly distributed on a sphere
based on https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf.
:returns: an :class:`~numpy.ndarray` of shape ``(3, npoints)``, where
``npoints`` does not generally equal *npoints_approx*.
"""
import numpy as np
points: list[np.ndarray] = []
count = 0
a = 4 * np.pi / npoints_approx
d = a ** 0.5
M_theta = int(np.ceil(np.pi / d)) # noqa: N806
d_theta = np.pi / M_theta
d_phi = a / d_theta
for m in range(M_theta):
theta = np.pi * (m + 0.5) / M_theta
M_phi = int(np.ceil(2 * np.pi * np.sin(theta) / d_phi)) # noqa: N806
for n in range(M_phi):
phi = 2 * np.pi * n / M_phi
points.append(np.array([
r * np.sin(theta) * np.cos(phi),
r * np.sin(theta) * np.sin(phi),
r * np.cos(theta)
]))
count += 1
# add poles
for i in range(3):
for sign in [-1, +1]:
pole = np.zeros(3)
pole[i] = r * sign
points.append(pole)
return np.array(points).T.copy()
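# (Editor's addition, illustrative only -- not part of the original source.)
# The returned array has shape (3, npoints) with npoints only approximately
# equal to the request, and every point lies on the sphere of radius r.
def _example_sphere_equidistant() -> None:
    import numpy as np
    pts = sphere_sample_equidistant(100, r=2.0)
    assert pts.shape[0] == 3
    assert np.allclose(np.sum(pts**2, axis=0), 4.0)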
# NOTE: each tuple contains ``(epsilon, max_npoints)``
_SPHERE_FIBONACCI_OFFSET = (
(0.33, 24), (1.33, 177), (3.33, 890),
(10, 11000), (27, 39000), (75, 600000), (214, float("inf")),
)
def sphere_sample_fibonacci(
npoints: int, r: float = 1.0, *,
optimize: str | None = None):
"""Generate points on a sphere based on an offset Fibonacci lattice from [2]_.
.. [2] http://extremelearning.com.au/how-to-evenly-distribute-points-on-a-sphere-more-effectively-than-the-canonical-fibonacci-lattice/
:param optimize: takes the values: *None* to use the standard Fibonacci
lattice, ``"minimum"`` to minimize the nearest neighbor distances in the
lattice and ``"average"`` to minimize the average distances in the
lattice.
:returns: an :class:`~numpy.ndarray` of shape ``(3, npoints)``.
"""
import numpy as np
if optimize is None:
epsilon = 0.5
elif optimize == "minimum":
epsilon, _ = next(o for o in _SPHERE_FIBONACCI_OFFSET if npoints < o[1])
elif optimize == "average":
epsilon = 0.36
else:
raise ValueError(f"unknown 'optimize' choice: '{optimize}'")
golden_ratio = (1 + np.sqrt(5)) / 2
n = np.arange(npoints)
phi = 2.0 * np.pi * n / golden_ratio
theta = np.arccos(1.0 - 2.0 * (n + epsilon) / (npoints + 2 * epsilon - 1))
return np.stack([
r * np.sin(theta) * np.cos(phi),
r * np.sin(theta) * np.sin(phi),
r * np.cos(theta)
])
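# (Editor's addition, illustrative only -- not part of the original source.)
# The Fibonacci lattice returns exactly npoints points on the sphere;
# "minimum"/"average" only change the offset epsilon used above.
def _example_sphere_fibonacci() -> None:
    import numpy as np
    pts = sphere_sample_fibonacci(128, optimize="minimum")
    assert pts.shape == (3, 128)
    assert np.allclose(np.sum(pts**2, axis=0), 1.0)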
# }}}
# {{{ strtobool
def strtobool(val: str | None, default: bool | None = None) -> bool:
"""Convert a string representation of truth to True or False.
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Uppercase versions are
also accepted. Any other string raises a :class:`ValueError`. If *val* is
*None*, *default* is returned if given; otherwise a :class:`ValueError` is
raised.
Based on :func:`distutils.util.strtobool`.
:param val: Value to convert.
:param default: Value to return if *val* is None.
:returns: Truth value of *val*.
"""
if val is None and default is not None:
return default
if val is None:
raise ValueError(f"invalid truth value '{val}'. "
"Valid values are ('y', 'yes', 't', 'true', 'on', '1') "
"for 'True' and ('n', 'no', 'f', 'false', 'off', '0') "
"for 'False'. Uppercase versions are also accepted.")
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
if val in ("n", "no", "f", "false", "off", "0"):
return False
raise ValueError(f"invalid truth value '{val}'. "
"Valid values are ('y', 'yes', 't', 'true', 'on', '1') "
"for 'True' and ('n', 'no', 'f', 'false', 'off', '0') "
"for 'False'. Uppercase versions are also accepted.")
# }}}
# {{{ to_identifier
def to_identifier(s: str) -> str:
"""Convert a string to a valid Python identifier, by removing
non-alphanumeric, non-underscore characters, and prepending an underscore
if the string starts with a numeric character.
:param s: The string to convert to an identifier.
:returns: The converted string.
"""
if s.isidentifier():
return s
s = "".join(c for c in s if c.isalnum() or c == "_")
if len(s) == 0:
return "_"
if s[0].isdigit():
s = "_" + s
return s
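# (Editor's addition, illustrative only -- not part of the original source.)
# A few expected conversions, per the rules described above.
def _example_to_identifier() -> None:
    assert to_identifier("99 red balloons") == "_99redballoons"
    assert to_identifier("already_valid") == "already_valid"
    assert to_identifier("!!!") == "_"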
# }}}
# {{{ unique
def unique(seq: Iterable[T]) -> Collection[T]:
"""Return unique elements in *seq*, removing all duplicates. The internal
order of the elements is preserved. See also
:func:`itertools.groupby` (which removes consecutive duplicates)."""
return dict.fromkeys(seq)
def unique_difference(*args: Iterable[T]) -> Collection[T]:
r"""Return unique elements that are in the first iterable in *\*args* but not
in any of the others. The internal order of the elements is preserved."""
if not args:
return []
res = dict.fromkeys(args[0])
for seq in args[1:]:
for item in seq:
if item in res:
del res[item]
return res
def unique_intersection(*args: Iterable[T]) -> Collection[T]:
r"""Return unique elements that are common to all iterables in *\*args*.
The internal order of the elements is preserved."""
if not args:
return []
res = dict.fromkeys(args[0])
for seq in args[1:]:
seq = set(seq)
res = {item: None for item in res if item in seq}
return res
def unique_union(*args: Iterable[T]) -> Collection[T]:
r"""Return unique elements that are in any iterable in *\*args*.
The internal order of the elements is preserved."""
if not args:
return []
res: dict[T, None] = {}
for seq in args:
for item in seq:
if item not in res:
res[item] = None
return res
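# (Editor's addition, illustrative only -- not part of the original source.)
# The unique_* helpers preserve the order of first appearance; the helper
# name is hypothetical.
def _example_unique_helpers() -> None:
    assert list(unique([3, 1, 3, 2, 1])) == [3, 1, 2]
    assert list(unique_difference([1, 2, 3, 2], [2])) == [1, 3]
    assert list(unique_intersection([3, 2, 1], [1, 2, 4])) == [2, 1]
    assert list(unique_union([1, 2], [2, 3], [3, 4])) == [1, 2, 3, 4]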
# }}}
@dataclass_transform(frozen_default=True)
def opt_frozen_dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool | None = None,
match_args: bool = True,
kw_only: bool = False,
slots: bool = False,
# Added in 3.11.
weakref_slot: bool = False,
) -> Callable[[type[T]], type[T]]:
"""Like :func:`dataclasses.dataclass`, but marks the dataclass frozen
only if :data:`__debug__` is active. Frozen dataclasses have a ~20%
cost penalty (on creation, from having to call :meth:`object.__setattr__`) that
this decorator avoids when the interpreter runs with "optimization"
enabled.
The resulting dataclass supports hashing, even when it is not actually frozen,
if *unsafe_hash* is left at the default or set to *True*.
.. note::
Python prevents non-frozen dataclasses from inheriting from frozen ones,
and vice versa. To ensure frozen-ness is applied predictably in all
scenarios (mainly :data:`__debug__` on and off), it is strongly recommended
that all dataclasses inheriting from ones with this decorator *also*
use this decorator. There are no run-time checks to make sure of this.
.. versionadded:: 2024.1.18
"""
def map_cls(cls: type[T]) -> type[T]:
# This ensures that the resulting dataclass is hashable with and without
# __debug__, unless the user overrides unsafe_hash or provides their own
# __hash__ method.
if unsafe_hash is None:
if (eq
and not __debug__
and "__hash__" not in cls.__dict__):
loc_unsafe_hash = True
else:
loc_unsafe_hash = False
else:
loc_unsafe_hash = unsafe_hash
dc_extra_kwargs: dict[str, bool] = {}
if weakref_slot:
if sys.version_info < (3, 11):
raise TypeError("weakref_slot is not available before Python 3.11")
dc_extra_kwargs["weakref_slot"] = weakref_slot
from dataclasses import dataclass
return dataclass(
init=init,
repr=repr,
eq=eq,
order=order,
unsafe_hash=loc_unsafe_hash,
frozen=__debug__,
match_args=match_args,
kw_only=kw_only,
slots=slots,
**dc_extra_kwargs,
)(cls)
return map_cls
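# (Editor's addition, illustrative only -- not part of the original source.)
# Typical use of opt_frozen_dataclass; the class name is hypothetical. Under
# ``python -O`` instances are not actually frozen, but they stay hashable.
@opt_frozen_dataclass()
class _ExamplePoint:
    x: float
    y: float


def _example_opt_frozen_dataclass() -> None:
    p = _ExamplePoint(1.0, 2.0)
    assert p == _ExamplePoint(1.0, 2.0)
    assert len({p, _ExamplePoint(1.0, 2.0)}) == 1   # hashable either way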
def _test():
import doctest
doctest.testmod()
......
from __future__ import absolute_import
import six
from __future__ import annotations
def _cp(src, dest):
......@@ -22,7 +21,7 @@ def get_timestamp():
return datetime.now().strftime("%Y-%m-%d-%H%M%S")
class BatchJob(object):
class BatchJob:
def __init__(self, moniker, main_file, aux_files=(), timestamp=None):
import os
import os.path
......@@ -44,10 +43,9 @@ class BatchJob(object):
os.makedirs(self.path)
runscript = open("%s/run.sh" % self.path, "w")
runscript = open(f"{self.path}/run.sh", "w")
import sys
runscript.write("%s %s setup.cpy"
% (sys.executable, main_file))
runscript.write(f"{sys.executable} {main_file} setup.cpy")
runscript.close()
from os.path import basename
......@@ -65,7 +63,7 @@ class BatchJob(object):
setup.close()
class INHERIT(object): # noqa
class INHERIT:
pass
......@@ -80,20 +78,20 @@ class GridEngineJob(BatchJob):
from os import getenv
env = dict(env)
for var, value in six.iteritems(env):
for var, value in env.items():
if value is INHERIT:
value = getenv(var)
args += ["-v", "%s=%s" % (var, value)]
args += ["-v", f"{var}={value}"]
if memory_megs is not None:
args.extend(["-l", "mem=%d" % memory_megs])
args.extend(["-l", f"mem={memory_megs}"])
args.extend(extra_args)
subproc = Popen(["qsub"] + args + ["run.sh"], cwd=self.path)
subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path)
if subproc.wait() != 0:
raise RuntimeError("Process submission of %s failed" % self.moniker)
raise RuntimeError(f"Process submission of {self.moniker} failed")
class PBSJob(BatchJob):
......@@ -106,32 +104,31 @@ class PBSJob(BatchJob):
]
if memory_megs is not None:
args.extend(["-l", "pmem=%dmb" % memory_megs])
args.extend(["-l", f"pmem={memory_megs}mb"])
from os import getenv
env = dict(env)
for var, value in six.iteritems(env):
for var, value in env.items():
if value is INHERIT:
value = getenv(var)
args += ["-v", "%s=%s" % (var, value)]
args += ["-v", f"{var}={value}"]
args.extend(extra_args)
subproc = Popen(["qsub"] + args + ["run.sh"], cwd=self.path)
subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path)
if subproc.wait() != 0:
raise RuntimeError("Process submission of %s failed" % self.moniker)
raise RuntimeError(f"Process submission of {self.moniker} failed")
def guess_job_class():
from subprocess import Popen, PIPE, STDOUT
from subprocess import PIPE, STDOUT, Popen
qstat_helplines = Popen(["qstat", "--help"],
stdout=PIPE, stderr=STDOUT).communicate()[0].split("\n")
if qstat_helplines[0].startswith("GE"):
return GridEngineJob
else:
return PBSJob
return PBSJob
class ConstructorPlaceholder:
......@@ -147,11 +144,11 @@ class ConstructorPlaceholder:
return self.kwargs[name]
def __str__(self):
return "%s(%s)" % (self.classname,
return "{}({})".format(self.classname,
",".join(
[str(arg) for arg in self.args]
+ ["%s=%s" % (kw, repr(val))
for kw, val in six.iteritems(self.kwargs)]
+ [f"{kw}={val!r}"
for kw, val in self.kwargs.items()]
)
)
__repr__ = __str__
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
Tools for Source Code Generation
================================
.. autoclass:: CodeGenerator
.. autoclass:: Indentation
.. autofunction:: remove_common_indentation
"""
from typing import Any
# {{{ code generation
# loosely based on
# http://effbot.org/zone/python-code-generator.htm
class CodeGenerator:
"""Language-agnostic functionality for source code generation.
.. automethod:: extend
.. automethod:: get
.. automethod:: add_to_preamble
.. automethod:: __call__
.. automethod:: indent
.. automethod:: dedent
"""
def __init__(self) -> None:
self.preamble: list[str] = []
self.code: list[str] = []
self.level = 0
self.indent_amount = 4
def extend(self, sub_generator: CodeGenerator) -> None:
for line in sub_generator.code:
self.code.append(" "*(self.indent_amount*self.level) + line)
def get(self) -> str:
result = "\n".join(self.code)
if self.preamble:
result = "\n".join(self.preamble) + "\n" + result
return result
def add_to_preamble(self, s: str) -> None:
self.preamble.append(s)
def __call__(self, s: str) -> None:
if not s.strip():
self.code.append("")
else:
if "\n" in s:
s = remove_common_indentation(s)
for line in s.split("\n"):
self.code.append(" "*(self.indent_amount*self.level) + line)
def indent(self) -> None:
self.level += 1
def dedent(self) -> None:
if self.level == 0:
raise RuntimeError("cannot decrease indentation level")
self.level -= 1
class Indentation:
"""A context manager for indentation for use with :class:`CodeGenerator`.
.. attribute:: generator
.. automethod:: __enter__
.. automethod:: __exit__
"""
def __init__(self, generator: CodeGenerator):
self.generator = generator
def __enter__(self) -> None:
self.generator.indent()
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.generator.dedent()
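# (Editor's addition, illustrative only -- not part of the original source.)
# CodeGenerator accumulates indented source text; the helper name and the
# generated snippet are hypothetical.
def _example_codegen() -> str:
    cg = CodeGenerator()
    cg.add_to_preamble("import math")
    cg("def area(r):")
    with Indentation(cg):
        cg("return math.pi * r**2")

    # -> "import math\ndef area(r):\n    return math.pi * r**2"
    return cg.get()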
# }}}
# {{{ remove common indentation
def remove_common_indentation(code: str, require_leading_newline: bool = True):
r"""Remove leading indentation from one or more lines of code.
Removes an amount of indentation equal to the indentation level of the first
nonempty line in *code*.
:param code: Input string.
:param require_leading_newline: If *True*, only remove indentation if *code*
starts with ``\n``.
:returns: A copy of *code* stripped of leading common indentation.
"""
if "\n" not in code:
return code
if require_leading_newline and not code.startswith("\n"):
return code
lines = code.split("\n")
while lines[0].strip() == "":
lines.pop(0)
while lines[-1].strip() == "":
lines.pop(-1)
if lines:
base_indent = 0
while lines[0][base_indent] in " \t":
base_indent += 1
for line in lines[1:]:
if line[:base_indent].strip():
raise ValueError("inconsistent indentation")
return "\n".join(line[base_indent:] for line in lines)
# }}}
# vim: foldmethod=marker
from __future__ import absolute_import
"""
.. autofunction:: estimate_order_of_convergence
.. autoclass:: EOCRecorder
.. autofunction:: stringify_eocs
.. autoclass:: PConvergenceVerifier
"""
from __future__ import annotations
import numbers
import numpy as np
from six.moves import range
from six.moves import zip
# {{{ eoc estimation --------------------------------------------------------------
def estimate_order_of_convergence(abscissae, errors):
"""Assuming that abscissae and errors are connected by a law of the form
r"""Assuming that abscissae and errors are connected by a law of the form
error = constant * abscissa ^ (order),
.. math::
\text{Error} = \text{constant} \cdot \text{abscissa }^{\text{order}},
this function finds, in a least-squares sense, the best approximation of
constant and order for the given data set. It returns a tuple (constant, order).
......@@ -22,35 +32,68 @@ def estimate_order_of_convergence(abscissae, errors):
return 10**coefficients[-1], coefficients[-2]
class EOCRecorder(object):
def __init__(self):
self.history = []
class EOCRecorder:
"""
.. automethod:: add_data_point
.. automethod:: estimate_order_of_convergence
.. automethod:: order_estimate
.. automethod:: max_error
.. automethod:: pretty_print
.. automethod:: write_gnuplot_file
"""
def __init__(self) -> None:
self.history: list[tuple[float, float]] = []
def add_data_point(self, abscissa: float, error: float) -> None:
if not (isinstance(abscissa, numbers.Number)
or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())):
raise TypeError(
f"'abscissa' is not a scalar: '{type(abscissa).__name__}'")
if not (isinstance(error, numbers.Number)
or (isinstance(error, np.ndarray) and error.shape == ())):
raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'")
def add_data_point(self, abscissa, error):
self.history.append((abscissa, error))
def estimate_order_of_convergence(self, gliding_mean=None):
def estimate_order_of_convergence(self,
gliding_mean: int | None = None,
) -> np.ndarray:
abscissae = np.array([a for a, e in self.history])
errors = np.array([e for a, e in self.history])
# NOTE: in case any of the errors are exactly 0.0, which
# can give NaNs in `estimate_order_of_convergence`
emax: float = np.amax(errors)
errors += (1 if emax == 0 else emax) * np.finfo(errors.dtype).eps
size = len(abscissae)
if gliding_mean is None:
gliding_mean = size
data_points = size - gliding_mean + 1
result = np.zeros((data_points, 2), float)
result: np.ndarray = np.zeros((data_points, 2), float)
for i in range(data_points):
result[i, 0], result[i, 1] = estimate_order_of_convergence(
abscissae[i:i+gliding_mean], errors[i:i+gliding_mean])
return result
def order_estimate(self):
def order_estimate(self) -> float:
return self.estimate_order_of_convergence()[0, 1]
def max_error(self):
def max_error(self) -> float:
return max(err for absc, err in self.history)
def pretty_print(self, abscissa_label="h", error_label="Error", gliding_mean=2):
def _to_table(self, *,
abscissa_label="h",
error_label="Error",
gliding_mean=2,
abscissa_format="%s",
error_format="%s",
eoc_format="%s"):
from pytools import Table
tbl = Table()
......@@ -58,37 +101,108 @@ class EOCRecorder(object):
gm_eoc = self.estimate_order_of_convergence(gliding_mean)
for i, (absc, err) in enumerate(self.history):
absc_str = abscissa_format % absc
err_str = error_format % err
if i < gliding_mean-1:
tbl.add_row((str(absc), str(err), ""))
eoc_str = ""
else:
tbl.add_row((str(absc), str(err), str(gm_eoc[i-gliding_mean+1, 1])))
eoc_str = eoc_format % (gm_eoc[i - gliding_mean + 1, 1])
tbl.add_row((absc_str, err_str, eoc_str))
if len(self.history) > 1:
return str(tbl) + "\n\nOverall EOC: %s" \
% self.estimate_order_of_convergence()[0, 1]
else:
order = self.estimate_order_of_convergence()[0, 1]
tbl.add_row(("Overall", "", eoc_format % order))
return tbl
def pretty_print(self, *,
abscissa_label: str = "h",
error_label: str = "Error",
gliding_mean: int = 2,
abscissa_format: str = "%s",
error_format: str = "%s",
eoc_format: str = "%s",
table_type: str = "markdown") -> str:
tbl = self._to_table(
abscissa_label=abscissa_label, error_label=error_label,
abscissa_format=abscissa_format,
error_format=error_format,
eoc_format=eoc_format,
gliding_mean=gliding_mean)
if table_type == "markdown":
return tbl.github_markdown()
if table_type == "latex":
return tbl.latex()
if table_type == "ascii":
return str(tbl)
if table_type == "csv":
return tbl.csv()
raise ValueError(f"unknown table type: {table_type}")
def __str__(self):
return self.pretty_print()
def write_gnuplot_file(self, filename):
def write_gnuplot_file(self, filename: str) -> None:
outfile = open(filename, "w")
for absc, err in self.history:
outfile.write("%f %f\n" % (absc, err))
outfile.write(f"{absc:f} {err:f}\n")
result = self.estimate_order_of_convergence()
const = result[0, 0]
order = result[0, 1]
outfile.write("\n")
for absc, err in self.history:
outfile.write("%f %f\n" % (absc, const * absc**(-order)))
for absc, _err in self.history:
outfile.write(f"{absc:f} {const * absc**(-order):f}\n")
def stringify_eocs(*eocs: EOCRecorder,
names: tuple[str, ...] | None = None,
abscissa_label: str = "h",
error_label: str = "Error",
gliding_mean: int = 2,
abscissa_format: str = "%s",
error_format: str = "%s",
eoc_format: str = "%s",
table_type: str = "markdown") -> str:
"""
:arg names: a :class:`tuple` of names to use for the *error_label* of each
*eoc*.
"""
if names is not None and len(names) < len(eocs):
raise ValueError(
f"insufficient names: got {len(names)} names for "
f"{len(eocs)} EOCRecorder instances")
if names is None:
names = tuple(f"{error_label} {i}" for i in range(len(eocs)))
from pytools import merge_tables
tbl = merge_tables(*[eoc._to_table(
abscissa_label=abscissa_label, error_label=name,
abscissa_format=abscissa_format,
error_format=error_format,
eoc_format=eoc_format,
gliding_mean=gliding_mean)
for name, eoc in zip(names, eocs, strict=True)
], skip_columns=(0,))
if table_type == "markdown":
return tbl.github_markdown()
if table_type == "latex":
return tbl.latex()
if table_type == "ascii":
return str(tbl)
if table_type == "csv":
return tbl.csv()
raise ValueError(f"unknown table type: {table_type}")
# }}}
# {{{ p convergence verifier
class PConvergenceVerifier(object):
class PConvergenceVerifier:
def __init__(self):
self.orders = []
self.errors = []
......@@ -102,7 +216,7 @@ class PConvergenceVerifier(object):
tbl = Table()
tbl.add_row(("p", "error"))
for p, err in zip(self.orders, self.errors):
for p, err in zip(self.orders, self.errors, strict=True):
tbl.add_row((str(p), str(err)))
return str(tbl)
......
from __future__ import absolute_import
from __future__ import annotations
import six
from six.moves import range, zip
from typing import IO, TYPE_CHECKING, Any
from pytools import Record
if TYPE_CHECKING:
from collections.abc import Callable, Iterator, Sequence
__doc__ = """
An in-memory relational database table
======================================
.. autoclass:: DataTable
"""
class Row(Record):
pass
class DataTable:
"""An in-memory relational database table."""
"""An in-memory relational database table.
.. automethod:: __init__
.. automethod:: copy
.. automethod:: deep_copy
.. automethod:: join
"""
def __init__(self, column_names, column_data=None):
def __init__(self, column_names: Sequence[str],
column_data: list[Any] | None = None) -> None:
"""Construct a new table, with the given C{column_names}.
@arg column_names: An indexable of column name strings.
@arg column_data: None or a list of tuples of the same length as
C{column_names} indicating an initial set of data.
:arg column_names: An indexable of column name strings.
:arg column_data: None or a list of tuples of the same length as
*column_names* indicating an initial set of data.
"""
if column_data is None:
self.data = []
......@@ -26,64 +43,64 @@ class DataTable:
self.data = column_data
self.column_names = column_names
self.column_indices = dict(
(colname, i) for i, colname in enumerate(column_names))
self.column_indices = {
colname: i for i, colname in enumerate(column_names)}
if len(self.column_indices) != len(self.column_names):
raise RuntimeError("non-unique column names encountered")
def __bool__(self):
def __bool__(self) -> bool:
return bool(self.data)
def __len__(self):
def __len__(self) -> int:
return len(self.data)
def __iter__(self):
def __iter__(self) -> Iterator[list[Any]]:
return self.data.__iter__()
def __str__(self):
def __str__(self) -> str:
"""Return a pretty-printed version of the table."""
def col_width(i):
def col_width(i: int) -> int:
width = len(self.column_names[i])
if self:
width = max(width, max(len(str(row[i])) for row in self.data))
return width
col_widths = [col_width(i) for i in range(len(self.column_names))]
def format_row(row):
def format_row(row: Sequence[str]) -> str:
return "|".join([str(cell).ljust(col_width)
for cell, col_width in zip(row, col_widths)])
for cell, col_width in zip(row, col_widths, strict=True)])
lines = [format_row(self.column_names),
"+".join("-"*col_width for col_width in col_widths)] + \
[format_row(row) for row in self.data]
return "\n".join(lines)
def insert(self, **kwargs):
def insert(self, **kwargs: Any) -> None:
values = [None for i in range(len(self.column_names))]
for key, val in six.iteritems(kwargs):
for key, val in kwargs.items():
values[self.column_indices[key]] = val
self.insert_row(tuple(values))
def insert_row(self, values):
def insert_row(self, values: tuple[Any, ...]) -> None:
assert isinstance(values, tuple)
assert len(values) == len(self.column_names)
self.data.append(values)
def insert_rows(self, rows):
def insert_rows(self, rows: Sequence[tuple[Any, ...]]) -> None:
for row in rows:
self.insert_row(row)
def filtered(self, **kwargs):
def filtered(self, **kwargs: Any) -> DataTable:
if not kwargs:
return self
criteria = tuple(
(self.column_indices[key], value)
for key, value in six.iteritems(kwargs))
for key, value in kwargs.items())
result_data = []
......@@ -99,43 +116,44 @@ class DataTable:
return DataTable(self.column_names, result_data)
def get(self, **kwargs):
def get(self, **kwargs: Any) -> Row:
filtered = self.filtered(**kwargs)
if not filtered:
raise RuntimeError("no matching entry for get()")
if len(filtered) > 1:
raise RuntimeError("more than one matching entry for get()")
return Row(dict(list(zip(self.column_names, filtered.data[0]))))
return Row(dict(zip(self.column_names, filtered.data[0], strict=True)))
def clear(self):
def clear(self) -> None:
del self.data[:]
def copy(self):
def copy(self) -> DataTable:
"""Make a copy of the instance, but leave individual rows untouched.
If the rows are modified later, they will also be modified in the copy.
"""
return DataTable(self.column_names, self.data[:])
def deep_copy(self):
def deep_copy(self) -> DataTable:
"""Make a copy of the instance down to the row level.
The copy's rows may be modified independently from the original.
"""
return DataTable(self.column_names, [row[:] for row in self.data])
def sort(self, columns, reverse=False):
def sort(self, columns: Sequence[str], reverse: bool = False) -> None:
col_indices = [self.column_indices[col] for col in columns]
def mykey(row):
def mykey(row: Sequence[Any]) -> tuple[Any, ...]:
return tuple(
row[col_index]
for col_index in col_indices)
self.data.sort(reverse=reverse, key=mykey)
def aggregated(self, groupby, agg_column, aggregate_func):
def aggregated(self, groupby: Sequence[str], agg_column: str,
aggregate_func: Callable[[Sequence[Any]], Any]) -> DataTable:
gb_indices = [self.column_indices[col] for col in groupby]
agg_index = self.column_indices[agg_column]
......@@ -144,14 +162,14 @@ class DataTable:
result_data = []
# to pacify pyflakes:
last_values = None
agg_values = None
last_values: tuple[Any, ...] = ()
agg_values: list[Row] = []
for row in self.data:
this_values = tuple(row[i] for i in gb_indices)
if first or this_values != last_values:
if not first:
result_data.append(last_values + (aggregate_func(agg_values),))
result_data.append((*last_values, aggregate_func(agg_values)))
agg_values = [row[agg_index]]
last_values = this_values
......@@ -160,14 +178,15 @@ class DataTable:
agg_values.append(row[agg_index])
if not first and agg_values:
result_data.append(this_values + (aggregate_func(agg_values),))
result_data.append((*this_values, aggregate_func(agg_values)))
return DataTable(
[self.column_names[i] for i in gb_indices] + [agg_column],
result_data)
def join(self, column, other_column, other_table, outer=False):
"""Return a tabled joining this and the C{other_table} on C{column}.
def join(self, column: str, other_column: str, other_table: DataTable,
outer: bool = False) -> DataTable:
"""Return a table joining this and the C{other_table} on C{column}.
The new table has the following columns:
- C{column}, titled the same as in this table.
......@@ -176,9 +195,9 @@ class DataTable:
Assumes both tables are sorted ascendingly by the column
by which they are joined.
""" # pylint:disable=too-many-locals,too-many-branches
"""
def without(indexable, idx):
def without(indexable: tuple[str, ...], idx: int) -> tuple[str, ...]:
return indexable[:idx] + indexable[idx+1:]
this_key_idx = self.column_indices[column]
......@@ -187,9 +206,9 @@ class DataTable:
this_iter = self.data.__iter__()
other_iter = other_table.data.__iter__()
result_columns = [self.column_names[this_key_idx]] + \
without(self.column_names, this_key_idx) + \
without(other_table.column_names, other_key_idx)
result_columns = (self.column_names[this_key_idx],) + \
without(tuple(self.column_names), this_key_idx) + \
without(tuple(other_table.column_names), other_key_idx)
result_data = []
......@@ -225,9 +244,8 @@ class DataTable:
except StopIteration:
this_over = True
break
else:
if outer:
this_batch = [(None,) * len(self.column_names)]
elif outer:
this_batch = [(None,) * len(self.column_names)]
if run_other and not other_over:
key = other_key
......@@ -238,36 +256,35 @@ class DataTable:
except StopIteration:
other_over = True
break
else:
if outer:
other_batch = [(None,) * len(other_table.column_names)]
elif outer:
other_batch = [(None,) * len(other_table.column_names)]
for this_batch_row in this_batch:
for other_batch_row in other_batch:
result_data.append((key,)
+ without(this_batch_row, this_key_idx)
+ without(other_batch_row, other_key_idx))
result_data.append((
key,
*without(this_batch_row, this_key_idx),
*without(other_batch_row, other_key_idx)))
if outer:
if this_over and other_over:
break
else:
if this_over or other_over:
break
elif this_over or other_over:
break
return DataTable(result_columns, result_data)
def restricted(self, columns):
def restricted(self, columns: Sequence[str]) -> DataTable:
col_indices = [self.column_indices[col] for col in columns]
return DataTable(columns,
[[row[i] for i in col_indices] for row in self.data])
def column_data(self, column):
def column_data(self, column: str) -> list[tuple[Any, ...]]:
col_index = self.column_indices[column]
return [row[col_index] for row in self.data]
def write_csv(self, filelike, **kwargs):
def write_csv(self, filelike: IO[Any], **kwargs: Any) -> None:
from csv import writer
csvwriter = writer(filelike, **kwargs)
csvwriter.writerow(self.column_names)
......
from __future__ import absolute_import, print_function
from __future__ import annotations
import sys
import six
from six.moves import input
from pytools import memoize
......@@ -13,8 +13,8 @@ def make_unique_filesystem_object(stem, extension="", directory="",
:param extension: needs a leading dot.
:param directory: must not have a trailing slash.
"""
from os.path import join
import os
from os.path import join
if creator is None:
def default_creator(name):
......@@ -24,7 +24,7 @@ def make_unique_filesystem_object(stem, extension="", directory="",
i = 0
while True:
fname = join(directory, "%s-%d%s" % (stem, i, extension))
fname = join(directory, f"{stem}-{i}{extension}")
try:
return creator(fname), fname
except OSError:
......@@ -53,11 +53,11 @@ def open_unique_debug_file(stem, extension=""):
# {{{ refcount debugging ------------------------------------------------------
class RefDebugQuit(Exception):
class RefDebugQuit(Exception): # noqa: N818
pass
def refdebug(obj, top_level=True, exclude=()): # noqa: E501 pylint:disable=too-many-locals,too-many-branches,too-many-statements
def refdebug(obj, top_level=True, exclude=()):
from types import FrameType
def is_excluded(o):
......@@ -99,10 +99,10 @@ def refdebug(obj, top_level=True, exclude=()): # noqa: E501 pylint:disable=too
else:
s = str(r)
print("%d/%d: " % (idx, len(reflist)), id(r), type(r), s)
print(f"{idx}/{len(reflist)}: ", id(r), type(r), s)
if isinstance(r, dict):
for k, v in six.iteritems(r):
for k, v in r.items():
if v is obj:
print("...referred to from key", k)
......@@ -111,7 +111,7 @@ def refdebug(obj, top_level=True, exclude=()): # noqa: E501 pylint:disable=too
response = input()
if response == "d":
refdebug(r, top_level=False, exclude=exclude+[reflist])
refdebug(r, top_level=False, exclude=exclude+tuple(reflist))
print_head = True
elif response == "n":
if idx + 1 < len(reflist):
......@@ -131,7 +131,7 @@ def refdebug(obj, top_level=True, exclude=()): # noqa: E501 pylint:disable=too
elif response == "r":
return
elif response == "q":
raise RefDebugQuit()
raise RefDebugQuit
else:
print("WHAT YOU SAY!!! (invalid choice)")
......@@ -143,10 +143,10 @@ def refdebug(obj, top_level=True, exclude=()): # noqa: E501 pylint:disable=too
# {{{ interactive shell
def get_shell_hist_filename():
def get_shell_hist_filename() -> str:
import os
_home = os.environ.get('HOME', '/')
return os.path.join(_home, ".pytools-debug-shell-history")
return os.path.expanduser(os.path.join("~", ".pytools-debug-shell-history"))
def setup_readline():
......@@ -156,12 +156,12 @@ def setup_readline():
try:
readline.read_history_file(hist_filename)
except Exception: # pylint:disable=broad-except
# http://docs.python.org/3/howto/pyporting.html#capturing-the-currently-raised-exception # noqa: E501 pylint:disable=line-too-long
# http://docs.python.org/3/howto/pyporting.html#capturing-the-currently-raised-exception
import sys
e = sys.exc_info()[1]
from warnings import warn
warn("Error opening readline history file: %s" % e)
warn(f"Error opening readline history file: {e}", stacklevel=2)
readline.parse_and_bind("tab: complete")
......@@ -216,4 +216,27 @@ def shell(locals_=None, globals_=None):
# }}}
# {{{ estimate memory usage
def estimate_memory_usage(root, seen_ids=None):
if seen_ids is None:
seen_ids = set()
id_root = id(root)
if id_root in seen_ids:
return 0
seen_ids.add(id_root)
result = sys.getsizeof(root)
from gc import get_referents
for ref in get_referents(root):
result += estimate_memory_usage(ref, seen_ids=seen_ids)
return result
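# (Editor's addition, illustrative only -- not part of the original source.)
# The estimate walks referents recursively and counts shared objects only
# once, so aliased data is not double-counted.
def _example_estimate_memory_usage() -> None:
    shared = list(range(1000))
    assert (estimate_memory_usage([shared, shared])
            < 2 * estimate_memory_usage(shared))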
# }}}
# vim: foldmethod=marker
from __future__ import absolute_import
# Python decorator module
# by Michele Simionato
# http://www.phyast.pitt.edu/~micheles/python/
## The basic trick is to generate the source code for the decorated function
## with the right signature and to evaluate it.
## Uncomment the statement 'print >> sys.stderr, func_src' in _decorate
## to understand what is going on.
__all__ = ["decorator", "update_wrapper", "getinfo"]
import inspect
def getinfo(func):
"""
Returns an info dictionary containing:
- name (the name of the function : str)
- argnames (the names of the arguments : list)
- defaults (the values of the default arguments : tuple)
- signature (the signature : str)
- doc (the docstring : str)
- module (the module name : str)
- dict (the function __dict__ : str)
>>> def f(self, x=1, y=2, *args, **kw): pass
>>> info = getinfo(f)
>>> info["name"]
'f'
>>> info["argnames"]
['self', 'x', 'y', 'args', 'kw']
>>> info["defaults"]
(1, 2)
>>> info["signature"]
'self, x, y, *args, **kw'
"""
assert inspect.ismethod(func) or inspect.isfunction(func)
regargs, varargs, varkwargs, defaults = inspect.getargspec(func)
argnames = list(regargs)
if varargs:
argnames.append(varargs)
if varkwargs:
argnames.append(varkwargs)
signature = inspect.formatargspec(regargs, varargs, varkwargs, defaults,
formatvalue=lambda value: "")[1:-1]
return dict(name=func.__name__, argnames=argnames, signature=signature,
defaults = func.__defaults__, doc=func.__doc__,
module=func.__module__, dict=func.__dict__,
globals=func.__globals__, closure=func.__closure__)
def update_wrapper(wrapper, wrapped, create=False):
"""
An improvement over functools.update_wrapper. By default it works the
same, but if the 'create' flag is set, generates a copy of the wrapper
with the right signature and updates the copy, not the original.
Moreover, 'wrapped' can be a dictionary with keys 'name', 'doc', 'module',
'dict', 'defaults'.
"""
if isinstance(wrapped, dict):
infodict = wrapped
else: # assume wrapped is a function
infodict = getinfo(wrapped)
assert not '_wrapper_' in infodict["argnames"], \
'"_wrapper_" is a reserved argument name!'
if create: # create a brand new wrapper with the right signature
src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict
# import sys; print >> sys.stderr, src # for debugging purposes
wrapper = eval(src, dict(_wrapper_=wrapper))
try:
wrapper.__name__ = infodict['name']
except: # Python version < 2.4
pass
wrapper.__doc__ = infodict['doc']
wrapper.__module__ = infodict['module']
wrapper.__dict__.update(infodict['dict'])
wrapper.__defaults__ = infodict['defaults']
return wrapper
# the real meat is here
def _decorator(caller, func):
if not (inspect.ismethod(func) or inspect.isfunction(func)):
# skip all the fanciness, just do what works
return lambda *args, **kwargs: caller(func, *args, **kwargs)
infodict = getinfo(func)
argnames = infodict['argnames']
assert not ('_call_' in argnames or '_func_' in argnames), \
'You cannot use _call_ or _func_ as argument names!'
src = "lambda %(signature)s: _call_(_func_, %(signature)s)" % infodict
dec_func = eval(src, dict(_func_=func, _call_=caller))
return update_wrapper(dec_func, func)
def decorator(caller, func=None):
"""
General purpose decorator factory: takes a caller function as
input and returns a decorator with the same attributes.
A caller function is any function like this::
def caller(func, *args, **kw):
# do something
return func(*args, **kw)
Here is an example of usage:
>>> @decorator
... def chatty(f, *args, **kw):
... print "Calling %r" % f.__name__
... return f(*args, **kw)
>>> chatty.__name__
'chatty'
>>> @chatty
... def f(): pass
...
>>> f()
Calling 'f'
For sake of convenience, the decorator factory can also be called with
two arguments. In this case, ``decorator(caller, func)`` is just a
shortcut for ``decorator(caller)(func)``.
"""
from warnings import warn
warn("pytools.decorator is deprecated and will be removed in pytools 12. "
"Use the 'decorator' module directly instead.",
DeprecationWarning, stacklevel=2)
if func is None: # return a decorator function
return update_wrapper(lambda f : _decorator(caller, f), caller)
else: # return a decorated function
return _decorator(caller, func)
if __name__ == "__main__":
import doctest; doctest.testmod()
####################### LEGALESE ##################################
## Redistributions of source code must retain the above copyright
## notice, this list of conditions and the following disclaimer.
## Redistributions in bytecode form must reproduce the above copyright
## notice, this list of conditions and the following disclaimer in
## the documentation and/or other materials provided with the
## distribution.
## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
## DAMAGE.
from __future__ import annotations
__copyright__ = """
Copyright (C) 2009-2013 Andreas Kloeckner
Copyright (C) 2020 Matt Wala
Copyright (C) 2020 James Stevens
Copyright (C) 2024 Addison Alvey-Blanco
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
Graph Algorithms
================
.. note::
These functions are mostly geared towards directed graphs (digraphs).
.. autofunction:: reverse_graph
.. autofunction:: a_star
.. autofunction:: compute_sccs
.. autoexception:: CycleError
.. autofunction:: compute_topological_order
.. autofunction:: compute_transitive_closure
.. autofunction:: contains_cycle
.. autofunction:: compute_induced_subgraph
.. autofunction:: as_graphviz_dot
.. autofunction:: validate_graph
.. autofunction:: is_connected
.. autofunction:: undirected_graph_from_edges
.. autofunction:: get_reachable_nodes
Type Variables Used
-------------------
.. class:: _SupportsLT
A :class:`~typing.Protocol` for ``__lt__`` support.
.. class:: NodeT
Type of a graph node, can be any hashable type.
.. class:: GraphT
A :class:`collections.abc.Mapping` representing a directed
graph. The mapping contains one key representing each node in the
graph, and this key maps to a :class:`collections.abc.Collection` of its
successor nodes. Note that most functions expect that every graph node
is included as a key in the graph.
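For illustration, a digraph with edges ``1 -> 2``, ``1 -> 3``, and
``2 -> 3`` (hypothetical integer nodes) can be written as::

    {1: {2, 3}, 2: {3}, 3: set()}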
"""
from collections.abc import (
Callable,
Collection,
Hashable,
Iterable,
Iterator,
Mapping,
MutableSet,
)
from dataclasses import dataclass
from typing import (
Any,
Generic,
Protocol,
TypeAlias,
TypeVar,
)
NodeT = TypeVar("NodeT", bound=Hashable)
GraphT: TypeAlias = Mapping[NodeT, Collection[NodeT]]
# {{{ reverse_graph
def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]:
"""
Reverses a graph *graph*.
:returns: A :class:`dict` representing *graph* with edges reversed.
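For example (hypothetical nodes)::

    reverse_graph({1: {2}, 2: set()})   # -> {1: frozenset(), 2: frozenset({1})}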
"""
result: dict[NodeT, set[NodeT]] = {}
for node_key, successor_nodes in graph.items():
# Make sure every node is in the result even if it has no successors
result.setdefault(node_key, set())
for successor in successor_nodes:
result.setdefault(successor, set()).add(node_key)
return {k: frozenset(v) for k, v in result.items()}
# }}}
# {{{ a_star
def a_star(
initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT[NodeT],
estimate_remaining_cost: Callable[[NodeT], float] | None = None,
get_step_cost: Callable[[Any, NodeT], float] = lambda x, y: 1
) -> list[NodeT]:
"""
With the default cost and heuristic, this amounts to Dijkstra's algorithm.
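A usage sketch (hypothetical states; returns the path from start to goal)::

    a_star("a", "c", {"a": ["b"], "b": ["c"], "c": []})   # -> ["a", "b", "c"]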
"""
from heapq import heappop, heappush
if estimate_remaining_cost is None:
def estimate_remaining_cost(x: NodeT) -> float:
if x != goal_state:
return 1
return 0
class AStarNode:
__slots__ = ["parent", "path_cost", "state"]
def __init__(self, state: NodeT, parent: Any, path_cost: float) -> None:
self.state = state
self.parent = parent
self.path_cost = path_cost
inf = float("inf")
init_remcost = estimate_remaining_cost(initial_state)
assert init_remcost != inf
queue = [(init_remcost, AStarNode(initial_state, parent=None, path_cost=0))]
visited_states = set()
while queue:
_, top = heappop(queue)
visited_states.add(top.state)
if top.state == goal_state:
result = []
it: AStarNode | None = top
while it is not None:
result.append(it.state)
it = it.parent
return result[::-1]
for state in neighbor_map[top.state]:
if state in visited_states:
continue
remaining_cost = estimate_remaining_cost(state)
if remaining_cost == inf:
continue
step_cost = get_step_cost(top, state)
estimated_path_cost = top.path_cost+step_cost+remaining_cost
heappush(queue,
(estimated_path_cost,
AStarNode(state, top, path_cost=top.path_cost + step_cost)))
raise RuntimeError("no solution")
# }}}
# {{{ compute SCCs with Tarjan's algorithm
def compute_sccs(graph: GraphT[NodeT]) -> list[list[NodeT]]:
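"""Compute the strongly connected components of *graph* using an iterative
variant of Tarjan's algorithm.

:returns: A :class:`list` of lists of nodes, one per strongly connected
    component.

A small sketch (hypothetical nodes; list ordering may vary)::

    compute_sccs({1: [2], 2: [1], 3: []})   # e.g. [[2, 1], [3]]
"""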
to_search = set(graph.keys())
visit_order: dict[NodeT, int] = {}
scc_root = {}
sccs = []
while to_search:
top = next(iter(to_search))
call_stack: list[tuple[NodeT, Iterator[NodeT], NodeT | None]] = (
[(top, iter(graph[top]), None)])
visit_stack = []
visiting = set()
scc: list[NodeT] = []
while call_stack:
top, children, last_popped_child = call_stack.pop()
if top not in visiting:
# Unvisited: mark as visited, initialize SCC root.
count = len(visit_order)
visit_stack.append(top)
visit_order[top] = count
scc_root[top] = count
visiting.add(top)
to_search.discard(top)
# Returned from a recursion, update SCC.
if last_popped_child is not None:
scc_root[top] = min(
scc_root[top],
scc_root[last_popped_child])
for child in children:
if child not in visit_order:
# Recurse.
call_stack.append((top, children, child))
call_stack.append((child, iter(graph[child]), None))
break
if child in visiting:
scc_root[top] = min(
scc_root[top],
visit_order[child])
else:
if scc_root[top] == visit_order[top]:
scc = []
while visit_stack[-1] != top:
scc.append(visit_stack.pop())
scc.append(visit_stack.pop())
for item in scc:
visiting.remove(item)
sccs.append(scc)
return sccs
# }}}
# {{{ compute topological order
class CycleError(Exception):
"""
Raised when a topological ordering cannot be computed due to a cycle.
:attr node: Node in a directed graph that is part of a cycle.
"""
def __init__(self, node: NodeT) -> None:
self.node = node
class _SupportsLT(Protocol):
def __lt__(self, other: Any) -> bool:
...
@dataclass(frozen=True)
class _HeapEntry(Generic[NodeT]):
"""
Helper class to compare associated keys while comparing the elements in
heap operations.
Only needs to define :func:`pytools.graph.__lt__` according to
<https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>.
"""
node: NodeT
key: _SupportsLT
def __lt__(self, other: _HeapEntry[NodeT]) -> bool:
return self.key < other.key
def compute_topological_order(
graph: GraphT[NodeT],
key: Callable[[NodeT], _SupportsLT] | None = None,
) -> list[NodeT]:
"""Compute a topological order of nodes in a directed graph.
:arg key: A custom key function may be supplied to determine the order in
break-even cases. Expects a function of one argument that is used to
extract a comparison key from each node of the *graph*.
:returns: A :class:`list` representing a valid topological ordering of the
nodes in the directed graph.
.. note::
* Requires the keys of the mapping *graph* to be hashable.
* Implements `Kahn's algorithm <https://w.wiki/YDy>`__.
.. versionadded:: 2020.2
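For example (hypothetical nodes)::

    compute_topological_order({3: [1], 1: [2], 2: []})   # -> [3, 1, 2]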
"""
# all nodes have the same keys when not provided
keyfunc = key if key is not None else (lambda x: 0)
from heapq import heapify, heappop, heappush
order = []
# {{{ compute nodes_to_num_predecessors
nodes_to_num_predecessors = dict.fromkeys(graph, 0)
for node in graph:
for child in graph[node]:
nodes_to_num_predecessors[child] = (
nodes_to_num_predecessors.get(child, 0) + 1)
# }}}
total_num_nodes = len(nodes_to_num_predecessors)
# heap: list of instances of HeapEntry(n) where 'n' is a node in
# 'graph' with no predecessor. Nodes with no predecessors are the
# schedulable candidates.
heap = [_HeapEntry(n, keyfunc(n))
for n, num_preds in nodes_to_num_predecessors.items()
if num_preds == 0]
heapify(heap)
while heap:
# pick the node with least key
node_to_be_scheduled = heappop(heap).node
order.append(node_to_be_scheduled)
# discard 'node_to_be_scheduled' from the predecessors of its
# successors since it's been scheduled
for child in graph.get(node_to_be_scheduled, ()):
nodes_to_num_predecessors[child] -= 1
if nodes_to_num_predecessors[child] == 0:
heappush(heap, _HeapEntry(child, keyfunc(child)))
if len(order) != total_num_nodes:
# any node which has a predecessor left is a part of a cycle
raise CycleError(next(iter(n for n, num_preds in
nodes_to_num_predecessors.items() if num_preds != 0)))
return order
# }}}
# {{{ compute transitive closure
def compute_transitive_closure(
graph: Mapping[NodeT, MutableSet[NodeT]]) -> GraphT[NodeT]:
"""Compute the transitive closure of a directed graph using Warshall's
algorithm.
:arg graph: A :class:`collections.abc.Mapping` representing a directed
graph. The mapping contains one key representing each node in the
graph, and this key maps to a :class:`collections.abc.MutableSet` of
nodes that are connected to the node by outgoing edges. This graph may
contain cycles. This object must be picklable. Every graph node must
be included as a key in the graph.
:returns: The transitive closure of the graph, represented using the same
data type.
.. versionadded:: 2020.2
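For example (hypothetical nodes; note that the result is not reflexive)::

    compute_transitive_closure({1: {2}, 2: {3}, 3: set()})
    # -> {1: {2, 3}, 2: {3}, 3: set()}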
"""
# Warshall's algorithm
from copy import deepcopy
closure = deepcopy(graph)
# (assumes all graph nodes are included in keys)
for k in graph.keys():
for n1 in graph.keys():
for n2 in graph.keys():
if k in closure[n1] and n2 in closure[k]:
closure[n1].add(n2)
return closure
# }}}
# {{{ check for cycle
def contains_cycle(graph: GraphT[NodeT]) -> bool:
"""Determine whether a graph contains a cycle.
:returns: A :class:`bool` indicating whether the graph contains a cycle.
.. versionadded:: 2020.2
"""
try:
compute_topological_order(graph)
return False
except CycleError:
return True
# }}}
# {{{ compute induced subgraph
def compute_induced_subgraph(graph: Mapping[NodeT, set[NodeT]],
subgraph_nodes: set[NodeT]) -> GraphT[NodeT]:
"""Compute the induced subgraph formed by a subset of the vertices in a
graph.
:arg graph: A :class:`collections.abc.Mapping` representing a directed
graph. The mapping contains one key representing each node in the
graph, and this key maps to a :class:`collections.abc.Set` of nodes
that are connected to the node by outgoing edges.
:arg subgraph_nodes: A :class:`collections.abc.Set` containing a subset of
the graph nodes in the graph.
:returns: A :class:`dict` representing the induced subgraph formed by
the subset of the vertices included in `subgraph_nodes`.
.. versionadded:: 2020.2
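For example (hypothetical nodes)::

    compute_induced_subgraph({1: {2, 3}, 2: {3}, 3: set()}, {1, 2})
    # -> {1: {2}, 2: set()}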
"""
new_graph = {}
for node, children in graph.items():
if node in subgraph_nodes:
new_graph[node] = children & subgraph_nodes
return new_graph
# }}}
# {{{ as_graphviz_dot
def as_graphviz_dot(graph: GraphT[NodeT],
node_labels: Callable[[NodeT], str] | None = None,
edge_labels: Callable[[NodeT, NodeT], str] | None = None,
) -> str:
"""
Create a visualization of the graph *graph* in the
`dot <http://graphviz.org/>`__ language.
:arg node_labels: An optional function that returns node labels
for each node.
:arg edge_labels: An optional function that returns edge labels
for each pair of nodes.
:returns: A string in the `dot <http://graphviz.org/>`__ language.
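A usage sketch (hypothetical two-node graph)::

    dot_code = as_graphviz_dot({"a": {"b"}, "b": set()})
    # pass dot_code to show_dot(), or write it to a .dot file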
"""
from pytools import UniqueNameGenerator
id_gen = UniqueNameGenerator(forced_prefix="mynode")
from pytools.graphviz import dot_escape
if node_labels is None:
def node_labels(x: NodeT) -> str:
return str(x)
if edge_labels is None:
def edge_labels(x: NodeT, y: NodeT) -> str:
return ""
node_to_id = {}
for node, targets in graph.items():
if node not in node_to_id:
node_to_id[node] = id_gen()
for t in targets:
if t not in node_to_id:
node_to_id[t] = id_gen()
# Add nodes
content = "\n".join(
[f'{node_to_id[node]} [label="{dot_escape(node_labels(node))}"];'
for node in node_to_id])
content += "\n"
# Add edges
content += "\n".join(
[f"{node_to_id[node]} -> {node_to_id[t]} "
f'[label="{dot_escape(edge_labels(node, t))}"];'
for (node, targets) in graph.items()
for t in targets])
return f"digraph mygraph {{\n{content}\n}}\n"
# }}}
# {{{ validate graph
def validate_graph(graph: GraphT[NodeT]) -> None:
"""
Validates that all successor nodes of each node in *graph* are keys in
*graph* itself. Raises a :class:`ValueError` if not.
"""
seen_nodes: set[NodeT] = set()
for children in graph.values():
seen_nodes.update(children)
if not seen_nodes <= graph.keys():
raise ValueError(
f"invalid graph, missing keys: {seen_nodes-graph.keys()}")
# }}}
# {{{ is_connected
def is_connected(graph: GraphT[NodeT]) -> bool:
"""
Returns whether all nodes in *graph* are connected, ignoring
the edge direction.
:returns: A :class:`bool` indicating whether the graph is connected.
"""
if not graph:
# https://cs.stackexchange.com/questions/52815/is-a-graph-of-zero-nodes-vertices-connected
return True
visited = set()
undirected_graph = {node: set(children) for node, children in graph.items()}
for node, children in graph.items():
for child in children:
undirected_graph[child].add(node)
def dfs(node: NodeT) -> None:
visited.add(node)
for child in undirected_graph[node]:
if child not in visited:
dfs(child)
dfs(next(iter(graph.keys())))
return visited == graph.keys()
# }}}
def undirected_graph_from_edges(
edges: Iterable[tuple[NodeT, NodeT]],
) -> GraphT[NodeT]:
"""
Constructs an undirected graph using *edges*.
:arg edges: An :class:`Iterable` of pairs of related :class:`NodeT` s.
:returns: A :class:`GraphT` that is the undirected graph.
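For example (hypothetical nodes)::

    undirected_graph_from_edges([(1, 2), (2, 3)])
    # -> {1: {2}, 2: {1, 3}, 3: {2}}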
"""
undirected_graph: dict[NodeT, set[NodeT]] = {}
for lhs, rhs in edges:
if lhs == rhs:
raise TypeError("Found loop in edges,"
f" LHS, RHS = {lhs}")
undirected_graph.setdefault(lhs, set()).add(rhs)
undirected_graph.setdefault(rhs, set()).add(lhs)
return undirected_graph
def get_reachable_nodes(
undirected_graph: GraphT[NodeT],
source_node: NodeT,
exclude_nodes: Collection[NodeT] | None = None) -> frozenset[NodeT]:
"""
Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
reachable from *source_node*.
If any node from *exclude_nodes* lies on a path between *source_node* and
some other node :math:`u` in *undirected_graph* and there are no other
viable paths, then :math:`u` is considered not reachable from *source_node*.
In the case where *source_node* is in *exclude_nodes*, then no node is
reachable from *source_node*, so an empty :class:`frozenset` is returned.
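For example (hypothetical chain graph)::

    g = undirected_graph_from_edges([(1, 2), (2, 3)])
    get_reachable_nodes(g, 1)                      # -> frozenset({1, 2, 3})
    get_reachable_nodes(g, 1, exclude_nodes={2})   # -> frozenset({1})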
"""
if exclude_nodes is not None and source_node in exclude_nodes:
return frozenset()
nodes_visited: set[NodeT] = set()
reachable_nodes: set[NodeT] = set()
nodes_to_visit = {source_node}
if exclude_nodes is None:
exclude_nodes = set()
while nodes_to_visit:
current_node = nodes_to_visit.pop()
nodes_visited.add(current_node)
reachable_nodes.add(current_node)
neighbors = undirected_graph[current_node]
nodes_to_visit.update({
node for node in neighbors
if node not in nodes_visited and node not in exclude_nodes
})
return frozenset(reachable_nodes)
# vim: foldmethod=marker
from __future__ import annotations
__copyright__ = """
Copyright (C) 2013 Andreas Kloeckner
Copyright (C) 2014 Matt Wala
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
Dot helper functions
====================
.. autofunction:: dot_escape
.. autofunction:: show_dot
"""
import html
import logging
import os
logger = logging.getLogger(__name__)
# {{{ graphviz / dot interactive show
def dot_escape(s: str) -> str:
"""
Escape the string *s* for compatibility with the
`dot <http://graphviz.org/>`__ language, particularly
backslashes and HTML tags.
:arg s: The input string to escape.
:returns: *s* with special characters escaped.
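For example::

    dot_escape("a -> <b>")   # -> 'a -&gt; &lt;b&gt;'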
"""
# "\" and HTML are significant in graphviz.
return html.escape(s.replace("\\", "\\\\"))
def show_dot(dot_code: str, output_to: str | None = None) -> str | None:
"""
Visualize the graph represented by *dot_code*.
:arg dot_code: An instance of :class:`str` in the `dot <http://graphviz.org/>`__
language to visualize.
:arg output_to: An instance of :class:`str` that can be one of:
- ``"xwindow"`` to visualize the graph as an
`X window <https://en.wikipedia.org/wiki/X_Window_System>`_.
- ``"browser"`` to visualize the graph as an SVG file in the
system's default web-browser.
- ``"svg"`` to store the dot code as an SVG file on the file system.
Returns the path to the generated SVG file.
Defaults to ``"xwindow"`` if X11 support is present, otherwise defaults
to ``"browser"``.
:returns: Depends on *output_to*. If ``"svg"``, returns the path to the
generated SVG file, otherwise returns ``None``.
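A usage sketch (requires the graphviz ``dot`` executable to be installed)::

    svg_path = show_dot("digraph { a -> b; }", output_to="svg")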
"""
import subprocess
from tempfile import mkdtemp
temp_dir = mkdtemp(prefix="tmp_pytools_dot")
dot_file_name = "code.dot"
from os.path import join
with open(join(temp_dir, dot_file_name), "w") as dotf:
dotf.write(dot_code)
# {{{ preprocess 'output_to'
if output_to is None:
with subprocess.Popen(["dot", "-T?"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
) as proc:
assert proc.stderr, ("Could not execute the 'dot' program. "
"Please install the 'graphviz' package and "
"make sure it is in your $PATH.")
supported_formats = proc.stderr.read().decode()
if " x11 " in supported_formats and "DISPLAY" in os.environ:
output_to = "xwindow"
else:
output_to = "browser"
# }}}
if output_to == "xwindow":
subprocess.check_call(["dot", "-Tx11", dot_file_name], cwd=temp_dir)
elif output_to in ["browser", "svg"]:
svg_file_name = "code.svg"
subprocess.check_call(["dot", "-Tsvg", "-o", svg_file_name, dot_file_name],
cwd=temp_dir)
full_svg_file_name = join(temp_dir, svg_file_name)
logger.info("show_dot: svg written to '%s'", full_svg_file_name)
if output_to == "svg":
return full_svg_file_name
assert output_to == "browser"
from webbrowser import open as browser_open
browser_open("file://" + full_svg_file_name)
else:
raise ValueError("`output_to` can be one of 'xwindow', 'browser', or 'svg',"
f" got '{output_to}'")
return None
# }}}
# vim: foldmethod=marker
"""Backport of importlib.import_module from 3.x.
Downloaded from: https://github.com/sprintly/importlib
This code is based in the implementation of importlib.import_module()
in Python 2.7. The license is below.
========================================================================
1. This LICENSE AGREEMENT is between the Python Software Foundation
("PSF"), and the Individual or Organization ("Licensee") accessing and
otherwise using this software ("Python") in source or binary form and
its associated documentation.
2. Subject to the terms and conditions of this License Agreement, PSF
hereby grants Licensee a nonexclusive, royalty-free, world-wide
license to reproduce, analyze, test, perform and/or display publicly,
prepare derivative works, distribute, and otherwise use Python
alone or in any derivative version, provided, however, that PSF's
License Agreement and PSF's notice of copyright, i.e., "Copyright (c)
2001, 2002, 2003, 2004, 2005, 2006 Python Software Foundation; All Rights
Reserved" are retained in Python alone or in any derivative version
prepared by Licensee.
3. In the event Licensee prepares a derivative work that is based on
or incorporates Python or any part thereof, and wants to make
the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to Python.
4. PSF is making Python available to Licensee on an "AS IS"
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between PSF and
Licensee. This License Agreement does not grant permission to use PSF
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.
8. By copying, installing or otherwise using Python, Licensee
agrees to be bound by the terms and conditions of this License
Agreement.
"""
# While not critical (and in no way guaranteed!), it would be nice to keep this
# code compatible with Python 2.3.
import sys
import six
def _resolve_name(name, package, level):
"""Return the absolute name of the module to be imported."""
if not hasattr(package, 'rindex'):
raise ValueError("'package' not set to a string")
dot = len(package)
for _ in six.moves.xrange(level, 1, -1):
try:
dot = package.rindex('.', 0, dot)
except ValueError:
raise ValueError("attempted relative import beyond top-level "
"package")
return "%s.%s" % (package[:dot], name)
def import_module(name, package=None):
"""Import a module.
The 'package' argument is required when performing a relative import. It
specifies the package to use as the anchor point from which to resolve the
relative import to an absolute import.
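For example (``mypkg`` is a hypothetical package)::

    os_path = import_module("os.path")
    sibling = import_module(".sibling", package="mypkg")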
"""
if name.startswith('.'):
if not package:
raise TypeError("relative imports require the 'package' argument")
level = 0
for character in name:
if character != '.':
break
level += 1
name = _resolve_name(name[level:], package, level)
__import__(name)
return sys.modules[name]
from __future__ import absolute_import
from __future__ import print_function
from __future__ import annotations
import re
import six
class RuleError(RuntimeError):
......@@ -20,8 +19,8 @@ class InvalidTokenError(RuntimeError):
self.index = str_index
def __str__(self):
return "at index %d: ...%s..." % \
(self.index, self.string[self.index:self.index+20])
return "at index {}: ...{}...".format(
self.index, self.string[self.index:self.index+20])
class ParseError(RuntimeError):
......@@ -33,20 +32,19 @@ class ParseError(RuntimeError):
def __str__(self):
if self.Token is None:
return "%s at end of input" % self.message
else:
return "%s at index %d: ...%s..." % \
(self.message, self.Token[2],
self.string[self.Token[2]:self.Token[2]+20])
return f"{self.message} at end of input"
return "{} at index {}: ...{}...".format(
self.message, self.Token[2],
self.string[self.Token[2]:self.Token[2]+20])
class RE(object):
def __init__(self, s, flags=0):
class RE:
def __init__(self, s: str, flags: int = 0) -> None:
self.Content = s
self.RE = re.compile(s, flags)
def __repr__(self):
return "RE(%s)" % self.Content
def __repr__(self) -> str:
return f"RE({self.Content})"
def _matches_rule(rule, s, start, rule_dict, debug=False):
......@@ -74,10 +72,10 @@ def _matches_rule(rule, s, start, rule_dict, debug=False):
return my_match_length, None
return 0, None
elif isinstance(rule, six.string_types):
if isinstance(rule, str):
return _matches_rule(rule_dict[rule], s, start, rule_dict, debug)
elif isinstance(rule, RE):
if isinstance(rule, RE):
match_obj = rule.RE.match(s, start)
if match_obj:
return match_obj.end()-start, match_obj
......@@ -105,7 +103,7 @@ def lex(lex_table, s, debug=False, match_objects=False):
return result
class LexIterator(object):
class LexIterator:
def __init__(self, lexed, raw_str, lex_index=0):
self.lexed = lexed
self.raw_string = raw_str
......@@ -148,18 +146,16 @@ class LexIterator(object):
def raise_parse_error(self, msg):
if self.is_at_end():
raise ParseError(msg, self.raw_string, None)
else:
raise ParseError(msg, self.raw_string, self.lexed[self.index])
raise ParseError(msg, self.raw_string, self.lexed[self.index])
def expected(self, what_expected):
if self.is_at_end():
self.raise_parse_error(
"%s expected, end of input found instead" %
what_expected)
f"{what_expected} expected, end of input found instead")
else:
self.raise_parse_error(
"%s expected, %s found instead" %
(what_expected, str(self.next_tag())))
f"{what_expected} expected, {self.next_tag()} found instead")
def expect_not_end(self):
if self.is_at_end():
......
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
import logging
import six
from six.moves import range
from six.moves import zip
logger = logging.getLogger(__name__)
# {{{ timing function
def time():
"""Return elapsed CPU time, as a float, in seconds."""
import os
time_opt = os.environ.get("PYTOOLS_LOG_TIME") or "wall"
if time_opt == "wall":
from time import time
return time()
elif time_opt == "rusage":
from resource import getrusage, RUSAGE_SELF
return getrusage(RUSAGE_SELF).ru_utime
else:
raise RuntimeError("invalid timing method '%s'" % time_opt)
# }}}
# {{{ abstract logging interface
class LogQuantity(object):
"""A source of loggable scalars."""
sort_weight = 0
def __init__(self, name, unit=None, description=None):
self.name = name
self.unit = unit
self.description = description
@property
def default_aggregator(self):
return None
def tick(self):
"""Perform updates required at every :class:`LogManager` tick."""
pass
def __call__(self):
"""Return the current value of the diagnostic represented by this
:class:`LogQuantity` or None if no value is available.
This is only called if the invocation interval calls for it.
"""
raise NotImplementedError
class PostLogQuantity(LogQuantity):
"""A source of loggable scalars."""
sort_weight = 0
def prepare_for_tick(self):
pass
class MultiLogQuantity(object):
"""A source of multiple loggable scalars."""
sort_weight = 0
def __init__(self, names, units=None, descriptions=None):
self.names = names
if units is None:
units = len(names) * [None]
self.units = units
if descriptions is None:
descriptions = len(names) * [None]
self.descriptions = descriptions
@property
def default_aggregators(self):
return [None] * len(self.names)
def tick(self):
"""Perform updates required at every :class:`LogManager` tick."""
pass
def __call__(self):
"""Return an iterable of the current values of the diagnostic represented
by this :class:`MultiLogQuantity`.
This is only called if the invocation interval calls for it.
"""
raise NotImplementedError
class MultiPostLogQuantity(MultiLogQuantity, PostLogQuantity):
pass
class DtConsumer(object):
def __init__(self, dt):
self.dt = dt
def set_dt(self, dt):
self.dt = dt
class TimeTracker(DtConsumer):
def __init__(self, dt):
DtConsumer.__init__(self, dt)
self.t = 0
def tick(self):
self.t += self.dt
class SimulationLogQuantity(PostLogQuantity, DtConsumer):
"""A source of loggable scalars that needs to know the simulation timestep."""
def __init__(self, dt, name, unit=None, description=None):
PostLogQuantity.__init__(self, name, unit, description)
DtConsumer.__init__(self, dt)
class PushLogQuantity(LogQuantity):
def __init__(self, name, unit=None, description=None):
LogQuantity.__init__(self, name, unit, description)
self.value = None
def push_value(self, value):
if self.value is not None:
raise RuntimeError("can't push two values per cycle")
self.value = value
def __call__(self):
v = self.value
self.value = None
return v
class CallableLogQuantityAdapter(LogQuantity):
"""Adapt a 0-ary callable as a :class:`LogQuantity`."""
def __init__(self, callable, name, unit=None, description=None):
self.callable = callable
LogQuantity.__init__(self, name, unit, description)
def __call__(self):
return self.callable()
# }}}
# {{{ manager functionality
class _GatherDescriptor(object):
def __init__(self, quantity, interval):
self.quantity = quantity
self.interval = interval
class _QuantityData(object):
def __init__(self, unit, description, default_aggregator):
self.unit = unit
self.description = description
self.default_aggregator = default_aggregator
def _join_by_first_of_tuple(list_of_iterables):
loi = [i.__iter__() for i in list_of_iterables]
if not loi:
return
key_vals = [next(iter) for iter in loi]
keys = [kv[0] for kv in key_vals]
values = [kv[1] for kv in key_vals]
target_key = max(keys)
force_advance = False
i = 0
while True:
while keys[i] < target_key or force_advance:
try:
new_key, new_value = next(loi[i])
except StopIteration:
return
assert keys[i] < new_key
keys[i] = new_key
values[i] = new_value
if new_key > target_key:
target_key = new_key
force_advance = False
i += 1
if i >= len(loi):
i = 0
if min(keys) == target_key:
yield target_key, values[:]
force_advance = True
def _get_unique_id():
try:
from uuid import uuid1
except ImportError:
try:
import hashlib
checksum = hashlib.md5()
except ImportError:
# for Python < 2.5
import md5
checksum = md5.new()
from random import Random
rng = Random()
rng.seed()
for i in range(20):
checksum.update(str(rng.randrange(1 << 30)).encode('utf-32'))
return checksum.hexdigest()
else:
return uuid1().hex
def _get_random_suffix(n):
characters = (
[chr(65+i) for i in range(26)]
+ [chr(97+i) for i in range(26)]
+ [chr(48+i) for i in range(10)])
from random import choice
return "".join(choice(characters) for i in range(n))
def _set_up_schema(db_conn):
# initialize new database
db_conn.execute("""
create table quantities (
name text,
unit text,
description text,
default_aggregator blob)""")
db_conn.execute("""
create table constants (
name text,
value blob)""")
db_conn.execute("""
create table warnings (
rank integer,
step integer,
message text,
category text,
filename text,
lineno integer
)""")
schema_version = 2
return schema_version
class LogManager(object):
"""A parallel-capable diagnostic time-series logging facility.
It is meant to log data from a computation, with certain log
quantities available before a cycle, and certain other ones
afterwards. A timeline of invocations looks as follows::
tick_before()
compute...
tick_after()

tick_before()
compute...
tick_after()

...
In a time-dependent simulation, each group of :meth:`tick_before` and
:meth:`tick_after` calls captures data for a single time state,
namely the one that the data was in *before* the "compute"
step. However, some data (such as the length of the timestep taken
in a time-adaptive method) may only be available *after* the completion
of the "compute..." stage, which is why :meth:`tick_after` exists.
A :class:`LogManager` logs any number of named time series of floats to
a file. Non-time-series data, in the form of constants, is also
supported and saved.
If MPI parallelism is used, the "head rank" below always refers to
rank 0.
Command line tools called :command:`runalyzer` and :command:`logtool`
(deprecated) are available for looking at the data in a saved log.
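A minimal usage sketch (hypothetical file name and step count ``n_steps``;
mode "w" assumes the database does not exist yet)::

    mgr = LogManager("mylog.sqlite", "w")
    add_general_quantities(mgr)

    for step in range(n_steps):
        mgr.tick_before()
        # ... perform one step of the computation ...
        mgr.tick_after()

    mgr.close()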
"""
def __init__(self, filename=None, mode="r", mpi_comm=None, capture_warnings=True,
commit_interval=90):
"""Initialize this log manager instance.
:param filename: If given, the filename to which this log is bound.
If this database exists, the current state is loaded from it.
:param mode: One of "w", "r" for write, read. "w" assumes that the
database is initially empty. May also be "wu" to indicate that
a unique filename should be chosen automatically.
:arg mpi_comm: A :mod:`mpi4py.MPI.Comm`. If given, logs are
periodically synchronized to the head node, which then writes them
out to disk.
:param capture_warnings: Tap the Python warnings facility and save warnings
to the log file.
:param commit_interval: actually perform a commit only every N times a commit
is requested.
"""
assert isinstance(mode, six.string_types), "mode must be a string"
assert mode in ["w", "r", "wu"], "invalid mode"
self.quantity_data = {}
self.last_values = {}
self.before_gather_descriptors = []
self.after_gather_descriptors = []
self.tick_count = 0
self.commit_interval = commit_interval
self.commit_countdown = commit_interval
self.constants = {}
self.last_save_time = time()
# self-timing
self.start_time = time()
self.t_log = 0
# parallel support
self.head_rank = 0
self.mpi_comm = mpi_comm
self.is_parallel = mpi_comm is not None
if mpi_comm is None:
self.rank = 0
else:
self.rank = mpi_comm.rank
self.head_rank = 0
# watch stuff
self.watches = []
self.next_watch_tick = 1
self.have_nonlocal_watches = False
# database binding
try:
import sqlite3 as sqlite
except ImportError:
try:
from pysqlite2 import dbapi2 as sqlite
except ImportError:
raise ImportError("could not find a usable version of sqlite.")
if filename is None:
filename = ":memory:"
else:
if self.is_parallel:
filename += "-rank%d" % self.rank
while True:
suffix = ""
if mode == "wu":
suffix = "-"+_get_random_suffix(6)
self.db_conn = sqlite.connect(filename+suffix, timeout=30)
self.mode = mode
try:
self.db_conn.execute("select * from quantities;")
except sqlite.OperationalError:
# we're building a new database
if mode == "r":
raise RuntimeError("Log database '%s' not found" % filename)
self.schema_version = _set_up_schema(self.db_conn)
self.set_constant("schema_version", self.schema_version)
self.set_constant("is_parallel", self.is_parallel)
# set globally unique run_id
if self.is_parallel:
self.set_constant("unique_run_id",
self.mpi_comm.bcast(_get_unique_id(),
root=self.head_rank))
else:
self.set_constant("unique_run_id", _get_unique_id())
if self.is_parallel:
self.set_constant("rank_count", self.mpi_comm.Get_size())
else:
self.set_constant("rank_count", 1)
else:
# we've opened an existing database
if mode == "w":
raise RuntimeError("Log database '%s' already exists" % filename)
elif mode == "wu":
# try again with a new suffix
continue
self._load()
break
self.old_showwarning = None
if capture_warnings:
self.capture_warnings(True)
def capture_warnings(self, enable=True):
def _showwarning(message, category, filename, lineno, file=None, line=None):
try:
self.old_showwarning(message, category, filename, lineno, file, line)
except TypeError:
# cater to Python 2.5 and earlier
self.old_showwarning(message, category, filename, lineno)
if self.schema_version >= 1 and self.mode == "w":
if self.schema_version >= 2:
self.db_conn.execute("insert into warnings values (?,?,?,?,?,?)",
(self.rank, self.tick_count, str(message), str(category),
filename, lineno))
else:
self.db_conn.execute("insert into warnings values (?,?,?,?,?)",
(self.tick_count, str(message), str(category),
filename, lineno))
import warnings
if enable:
if self.old_showwarning is None:
self.old_showwarning = warnings.showwarning
warnings.showwarning = _showwarning
else:
raise RuntimeError("Warnings capture was enabled twice")
else:
if self.old_showwarning is None:
raise RuntimeError(
"Warnings capture was disabled, but never enabled")
else:
warnings.showwarning = self.old_showwarning
self.old_showwarning = None
def _load(self):
if self.mpi_comm and self.mpi_comm.rank != self.head_rank:
return
from pickle import loads
for name, value in self.db_conn.execute("select name, value from constants"):
self.constants[name] = loads(value)
self.schema_version = self.constants.get("schema_version", 0)
self.is_parallel = self.constants["is_parallel"]
for name, unit, description, def_agg in self.db_conn.execute(
"select name, unit, description, default_aggregator "
"from quantities"):
self.quantity_data[name] = _QuantityData(
unit, description, loads(def_agg))
def close(self):
if self.old_showwarning is not None:
self.capture_warnings(False)
self.save()
self.db_conn.close()
def get_table(self, q_name):
if q_name not in self.quantity_data:
raise KeyError("invalid quantity name '%s'" % q_name)
from pytools.datatable import DataTable
result = DataTable(["step", "rank", "value"])
for row in self.db_conn.execute(
"select step, rank, value from %s" % q_name):
result.insert_row(row)
return result
def get_warnings(self):
columns = ["step", "message", "category", "filename", "lineno"]
if self.schema_version >= 2:
columns.insert(0, "rank")
from pytools.datatable import DataTable
result = DataTable(columns)
for row in self.db_conn.execute(
"select %s from warnings" % (", ".join(columns))):
result.insert_row(row)
return result
def add_watches(self, watches):
"""Add quantities that are printed after every time step."""
from pytools import Record
class WatchInfo(Record):
pass
for watch in watches:
if isinstance(watch, tuple):
display, expr = watch
else:
display = watch
expr = watch
parsed = self._parse_expr(expr)
parsed, dep_data = self._get_expr_dep_data(parsed)
from pytools import any
self.have_nonlocal_watches = self.have_nonlocal_watches or \
any(dd.nonlocal_agg for dd in dep_data)
from pymbolic import compile
compiled = compile(parsed, [dd.varname for dd in dep_data])
watch_info = WatchInfo(display=display, parsed=parsed, dep_data=dep_data,
compiled=compiled)
self.watches.append(watch_info)
def set_constant(self, name, value):
"""Make a named, constant value available in the log."""
existed = name in self.constants
self.constants[name] = value
from pickle import dumps
value = bytes(dumps(value))
if existed:
self.db_conn.execute("update constants set value = ? where name = ?",
(value, name))
else:
self.db_conn.execute("insert into constants values (?,?)",
(name, value))
self._commit()
def _insert_datapoint(self, name, value):
if value is None:
return
self.last_values[name] = value
try:
self.db_conn.execute("insert into %s values (?,?,?)" % name,
(self.tick_count, self.rank, float(value)))
except Exception:
print("while adding datapoint for '%s':" % name)
raise
def _gather_for_descriptor(self, gd):
if self.tick_count % gd.interval == 0:
q_value = gd.quantity()
if isinstance(gd.quantity, MultiLogQuantity):
for name, value in zip(gd.quantity.names, q_value):
self._insert_datapoint(name, value)
else:
self._insert_datapoint(gd.quantity.name, q_value)
def tick(self):
"""Record data points from each added :class:`LogQuantity`.
May also checkpoint data to disk, and/or synchronize data points
to the head rank.
"""
from warnings import warn
warn("LogManager.tick() is deprecated. "
"Use LogManager.tick_{before,after}().",
DeprecationWarning)
self.tick_before()
self.tick_after()
def tick_before(self):
"""Record data points from each added :class:`LogQuantity` that
is not an instance of :class:`PostLogQuantity`. Also, invoke
:meth:`PostLogQuantity.prepare_for_tick` on :class:`PostLogQuantity`
instances.
"""
tick_start_time = time()
for gd in self.before_gather_descriptors:
self._gather_for_descriptor(gd)
for gd in self.after_gather_descriptors:
gd.quantity.prepare_for_tick()
self.t_log = time() - tick_start_time
def tick_after(self):
"""Record data points from each added :class:`LogQuantity` that
is an instance of :class:`PostLogQuantity`.
May also checkpoint data to disk.
"""
tick_start_time = time()
for gd_lst in [self.before_gather_descriptors,
self.after_gather_descriptors]:
for gd in gd_lst:
gd.quantity.tick()
for gd in self.after_gather_descriptors:
self._gather_for_descriptor(gd)
self.tick_count += 1
if tick_start_time - self.start_time > 15*60:
save_interval = 5*60
else:
save_interval = 15
if tick_start_time > self.last_save_time + save_interval:
self.save()
# print watches
if self.tick_count == self.next_watch_tick:
self._watch_tick()
self.t_log += time() - tick_start_time
def _commit(self):
self.commit_countdown -= 1
if self.commit_countdown <= 0:
self.commit_countdown = self.commit_interval
self.db_conn.commit()
def save(self):
from sqlite3 import OperationalError
try:
self.db_conn.commit()
except OperationalError as e:
from warnings import warn
warn("encountered sqlite error during commit: %s" % e)
self.last_save_time = time()
def add_quantity(self, quantity, interval=1):
"""Add an object derived from :class:`LogQuantity` to this manager."""
def add_internal(name, unit, description, def_agg):
logger.debug("add log quantity '%s'" % name)
if name in self.quantity_data:
raise RuntimeError("cannot add the same quantity '%s' twice" % name)
self.quantity_data[name] = _QuantityData(unit, description, def_agg)
from pickle import dumps
self.db_conn.execute("""insert into quantities values (?,?,?,?)""", (
name, unit, description,
bytes(dumps(def_agg))))
self.db_conn.execute("""create table %s
(step integer, rank integer, value real)""" % name)
self._commit()
gd = _GatherDescriptor(quantity, interval)
if isinstance(quantity, PostLogQuantity):
gd_list = self.after_gather_descriptors
else:
gd_list = self.before_gather_descriptors
gd_list.append(gd)
gd_list.sort(key=lambda gd: gd.quantity.sort_weight)
if isinstance(quantity, MultiLogQuantity):
for name, unit, description, def_agg in zip(
quantity.names,
quantity.units,
quantity.descriptions,
quantity.default_aggregators):
add_internal(name, unit, description, def_agg)
else:
add_internal(quantity.name,
quantity.unit, quantity.description,
quantity.default_aggregator)
def get_expr_dataset(self, expression, description=None, unit=None):
"""Prepare a time-series dataset for a given expression.
@arg expression: A C{pymbolic} expression that may involve
the time-series variables and the constants in this :class:`LogManager`.
If there is data from multiple ranks for a quantity occurring in
this expression, an aggregator may have to be specified.
@return: C{(description, unit, table)}, where C{table}
is a list of tuples C{(tick_nbr, value)}.
Aggregators are specified as follows:
- C{qty.min}, C{qty.max}, C{qty.avg}, C{qty.sum}, C{qty.norm2}
- C{qty[rank_nbr]}
- C{qty.loc}
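A usage sketch (here ``mgr`` is a :class:`LogManager`, and a quantity
named ``t_step`` is assumed to have been logged)::

    descr, unit, data = mgr.get_expr_dataset("t_step.max")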
"""
parsed = self._parse_expr(expression)
parsed, dep_data = self._get_expr_dep_data(parsed)
# aggregate table data
for dd in dep_data:
table = self.get_table(dd.name)
table.sort(["step"])
dd.table = table.aggregated(["step"], "value", dd.agg_func).data
# evaluate unit and description, if necessary
if unit is None:
from pymbolic import substitute, parse
unit_dict = dict((dd.varname, dd.qdat.unit) for dd in dep_data)
from pytools import all
if all(v is not None for v in six.itervalues(unit_dict)):
unit_dict = dict((k, parse(v)) for k, v in six.iteritems(unit_dict))
unit = substitute(parsed, unit_dict)
else:
unit = None
if description is None:
description = expression
# compile and evaluate
from pymbolic import compile
compiled = compile(parsed, [dd.varname for dd in dep_data])
data = []
for key, values in _join_by_first_of_tuple(dd.table for dd in dep_data):
try:
data.append((key, compiled(*values)))
except ZeroDivisionError:
pass
return (description, unit, data)
def get_joint_dataset(self, expressions):
"""Return a joint data set for a list of expressions.
@arg expressions: a list of either strings representing
expressions directly, or triples (descr, unit, expr).
In the former case, the description and the unit are
found automatically, if possible. In the latter case,
they are used as specified.
@return: A triple C{(descriptions, units, table)}, where
C{table} is a list of C{[(tstep, (val_expr1, val_expr2,...)...]}.
"""
# dubs is a list of (desc, unit, table) triples as
# returned by get_expr_dataset
dubs = []
for expr in expressions:
if isinstance(expr, str):
dub = self.get_expr_dataset(expr)
else:
expr_descr, expr_unit, expr_str = expr
dub = self.get_expr_dataset(
expr_str,
description=expr_descr,
unit=expr_unit)
dubs.append(dub)
zipped_dubs = list(zip(*dubs))
zipped_dubs[2] = list(
_join_by_first_of_tuple(zipped_dubs[2]))
return zipped_dubs
def get_plot_data(self, expr_x, expr_y, min_step=None, max_step=None):
"""Generate plot-ready data.
:return: ``(data_x, descr_x, unit_x), (data_y, descr_y, unit_y)``
"""
(descr_x, descr_y), (unit_x, unit_y), data = \
self.get_joint_dataset([expr_x, expr_y])
if min_step is not None:
data = [(step, tup) for step, tup in data if min_step <= step]
if max_step is not None:
data = [(step, tup) for step, tup in data if step <= max_step]
stepless_data = [tup for step, tup in data]
if stepless_data:
data_x, data_y = list(zip(*stepless_data))
else:
data_x = []
data_y = []
return (data_x, descr_x, unit_x), \
(data_y, descr_y, unit_y)
def write_datafile(self, filename, expr_x, expr_y):
(data_x, descr_x, unit_x), (data_y, descr_y, unit_y) = self.get_plot_data(
expr_x, expr_y)
outf = open(filename, "w")
outf.write("# %s [%s] vs. %s [%s]\n" % (descr_x, unit_x, descr_y, unit_y))
for dx, dy in zip(data_x, data_y):
outf.write("%s\t%s\n" % (repr(dx), repr(dy)))
outf.close()
def plot_matplotlib(self, expr_x, expr_y):
from pylab import xlabel, ylabel, plot
(data_x, descr_x, unit_x), (data_y, descr_y, unit_y) = \
self.get_plot_data(expr_x, expr_y)
xlabel("%s [%s]" % (descr_x, unit_x))
ylabel("%s [%s]" % (descr_y, unit_y))
plot(data_x, data_y)
# {{{ private functionality
def _parse_expr(self, expr):
from pymbolic import parse, substitute
parsed = parse(expr)
# substitute in global constants
parsed = substitute(parsed, self.constants)
return parsed
def _get_expr_dep_data(self, parsed):
class Nth:
def __init__(self, n):
self.n = n
def __call__(self, lst):
return lst[self.n]
from pymbolic.mapper.dependency import DependencyMapper
deps = DependencyMapper(include_calls=False)(parsed)
# gather information on aggregation expressions
dep_data = []
from pymbolic.primitives import Variable, Lookup, Subscript
for dep_idx, dep in enumerate(deps):
nonlocal_agg = True
if isinstance(dep, Variable):
name = dep.name
if name == "math":
continue
agg_func = self.quantity_data[name].default_aggregator
if agg_func is None:
if self.is_parallel:
raise ValueError(
"must specify explicit aggregator for '%s'" % name)
else:
agg_func = lambda lst: lst[0]
elif isinstance(dep, Lookup):
assert isinstance(dep.aggregate, Variable)
name = dep.aggregate.name
agg_name = dep.name
if agg_name == "loc":
agg_func = Nth(self.rank)
nonlocal_agg = False
elif agg_name == "min":
agg_func = min
elif agg_name == "max":
agg_func = max
elif agg_name == "avg":
from pytools import average
agg_func = average
elif agg_name == "sum":
agg_func = sum
elif agg_name == "norm2":
from math import sqrt
agg_func = lambda iterable: sqrt(
sum(entry**2 for entry in iterable))
else:
raise ValueError("invalid rank aggregator '%s'" % agg_name)
elif isinstance(dep, Subscript):
assert isinstance(dep.aggregate, Variable)
name = dep.aggregate.name
from pymbolic import evaluate
agg_func = Nth(evaluate(dep.index))
qdat = self.quantity_data[name]
from pytools import Record
class DependencyData(Record):
pass
this_dep_data = DependencyData(name=name, qdat=qdat, agg_func=agg_func,
varname="logvar%d" % dep_idx, expr=dep,
nonlocal_agg=nonlocal_agg)
dep_data.append(this_dep_data)
# substitute in the "logvar" variable names
from pymbolic import var, substitute
parsed = substitute(parsed,
dict((dd.expr, var(dd.varname)) for dd in dep_data))
return parsed, dep_data
def _watch_tick(self):
if not self.have_nonlocal_watches and self.rank != self.head_rank:
return
data_block = dict((qname, self.last_values.get(qname, 0))
for qname in six.iterkeys(self.quantity_data))
if self.mpi_comm is not None and self.have_nonlocal_watches:
gathered_data = self.mpi_comm.gather(data_block, self.head_rank)
else:
gathered_data = [data_block]
if self.rank == self.head_rank:
values = {}
for data_block in gathered_data:
for name, value in six.iteritems(data_block):
values.setdefault(name, []).append(value)
def compute_watch_str(watch):
try:
return "%s=%g" % (watch.display, watch.compiled(
*[dd.agg_func(values[dd.name])
for dd in watch.dep_data]))
except ZeroDivisionError:
return "%s:div0" % watch.display
if self.watches:
print(" | ".join(
compute_watch_str(watch) for watch in self.watches))
ticks_per_sec = self.tick_count/max(1, time()-self.start_time)
self.next_watch_tick = self.tick_count + int(max(1, ticks_per_sec))
if self.mpi_comm is not None and self.have_nonlocal_watches:
self.next_watch_tick = self.mpi_comm.bcast(
self.next_watch_tick, self.head_rank)
# }}}
# }}}
# {{{ actual data loggers
class _SubTimer:
def __init__(self, itimer):
self.itimer = itimer
self.start_time = time()
self.elapsed = 0
def stop(self):
self.elapsed += time() - self.start_time
del self.start_time
return self
def submit(self):
self.itimer.add_time(self.elapsed)
del self.elapsed
class IntervalTimer(PostLogQuantity):
"""Records elapsed times."""
def __init__(self, name, description=None):
LogQuantity.__init__(self, name, "s", description)
self.elapsed = 0
def start_sub_timer(self):
return _SubTimer(self)
def add_time(self, t):
self.start_time = time()
self.elapsed += t
def __call__(self):
result = self.elapsed
self.elapsed = 0
return result
class LogUpdateDuration(LogQuantity):
"""Records how long the last :meth:`LogManager.tick` invocation took."""
# FIXME this is off by one tick
def __init__(self, mgr, name="t_log"):
LogQuantity.__init__(self, name, "s", "Time spent updating the log")
self.log_manager = mgr
def __call__(self):
return self.log_manager.t_log
class EventCounter(PostLogQuantity):
"""Counts events signaled by :meth:`add`."""
def __init__(self, name="interval", description=None):
PostLogQuantity.__init__(self, name, "1", description)
self.events = 0
def add(self, n=1):
self.events += n
def transfer(self, counter):
self.events += counter.pop()
def prepare_for_tick(self):
self.events = 0
def __call__(self):
result = self.events
return result
def time_and_count_function(f, timer, counter=None, increment=1):
def inner_f(*args, **kwargs):
if counter is not None:
counter.add(increment)
sub_timer = timer.start_sub_timer()
try:
return f(*args, **kwargs)
finally:
sub_timer.stop().submit()
return inner_f
class TimestepCounter(LogQuantity):
"""Counts the number of times :meth:`LogManager.tick` is called."""
def __init__(self, name="step"):
LogQuantity.__init__(self, name, "1", "Timesteps")
self.steps = 0
def __call__(self):
result = self.steps
self.steps += 1
return result
class StepToStepDuration(PostLogQuantity):
"""Records the CPU time between invocations of
:meth:`LogManager.tick_before` and
:meth:`LogManager.tick_after`.
"""
def __init__(self, name="t_2step"):
PostLogQuantity.__init__(self, name, "s", "Step-to-step duration")
self.last_start_time = None
self.last2_start_time = None
def prepare_for_tick(self):
self.last2_start_time = self.last_start_time
self.last_start_time = time()
def __call__(self):
if self.last2_start_time is None:
return None
else:
return self.last_start_time - self.last2_start_time
class TimestepDuration(PostLogQuantity):
"""Records the CPU time between the starts of time steps.
:meth:`LogManager.tick_before` and
:meth:`LogManager.tick_after`.
"""
# We would like to run last, so that if log gathering takes any
# significant time, we catch that, too. (CUDA sync-on-time-taking,
# I'm looking at you.)
sort_weight = 1000
def __init__(self, name="t_step"):
PostLogQuantity.__init__(self, name, "s", "Time step duration")
def prepare_for_tick(self):
self.last_start = time()
def __call__(self):
now = time()
result = now - self.last_start
del self.last_start
return result
class CPUTime(LogQuantity):
"""Records (monotonically increasing) CPU time."""
def __init__(self, name="t_cpu"):
LogQuantity.__init__(self, name, "s", "Wall time")
self.start = time()
def __call__(self):
return time()-self.start
class ETA(LogQuantity):
"""Records an estimate of how long the computation will still take."""
def __init__(self, total_steps, name="t_eta"):
LogQuantity.__init__(self, name, "s", "Estimated remaining duration")
self.steps = 0
self.total_steps = total_steps
self.start = time()
def __call__(self):
fraction_done = self.steps/self.total_steps
self.steps += 1
time_spent = time()-self.start
if fraction_done > 1e-9:
return time_spent/fraction_done-time_spent
else:
return 0
def add_general_quantities(mgr):
"""Add generally applicable :class:`LogQuantity` objects to C{mgr}."""
mgr.add_quantity(TimestepDuration())
mgr.add_quantity(StepToStepDuration())
mgr.add_quantity(CPUTime())
mgr.add_quantity(LogUpdateDuration(mgr))
mgr.add_quantity(TimestepCounter())
class SimulationTime(TimeTracker, LogQuantity):
"""Record (monotonically increasing) simulation time."""
def __init__(self, dt, name="t_sim", start=0):
LogQuantity.__init__(self, name, "s", "Simulation Time")
TimeTracker.__init__(self, dt)
def __call__(self):
return self.t
class Timestep(SimulationLogQuantity):
"""Record the magnitude of the simulated time step."""
def __init__(self, dt, name="dt", unit="s"):
SimulationLogQuantity.__init__(self, dt, name, unit, "Simulation Timestep")
def __call__(self):
return self.dt
def set_dt(mgr, dt):
"""Set the simulation timestep on :class:`LogManager` C{mgr} to C{dt}."""
for gd_lst in [mgr.before_gather_descriptors,
mgr.after_gather_descriptors]:
for gd in gd_lst:
if isinstance(gd.quantity, DtConsumer):
gd.quantity.set_dt(dt)
def add_simulation_quantities(mgr, dt=None):
"""Add :class:`LogQuantity` objects relating to simulation time."""
if dt is not None:
from warnings import warn
warn("Specifying dt ahead of time is a deprecated practice. "
"Use pytools.log.set_dt() instead.")
mgr.add_quantity(SimulationTime(dt))
mgr.add_quantity(Timestep(dt))
def add_run_info(mgr):
"""Add generic run metadata, such as command line, host, and time."""
import sys
mgr.set_constant("cmdline", " ".join(sys.argv))
from socket import gethostname
mgr.set_constant("machine", gethostname())
from time import localtime, strftime, time
mgr.set_constant("date", strftime("%a, %d %b %Y %H:%M:%S %Z", localtime()))
mgr.set_constant("unixtime", time())
# }}}
# vim: foldmethod=marker
from __future__ import absolute_import
from __future__ import annotations
__copyright__ = """
Copyright (C) 2009-2019 Andreas Kloeckner
Copyright (C) 2022 University of Illinois Board of Trustees
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
MPI helper functionality
========================
.. autofunction:: check_for_mpi_relaunch
.. autofunction:: run_with_mpi_ranks
.. autofunction:: pytest_raises_on_rank
"""
import contextlib
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Generator
def check_for_mpi_relaunch(argv):
......@@ -17,8 +59,8 @@ def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
import sys
import os
import sys
newenv = os.environ.copy()
newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1"
......@@ -29,3 +71,24 @@ def run_with_mpi_ranks(py_script, ranks, callable_, args=(), kwargs=None):
check_call(["mpirun", "-np", str(ranks),
sys.executable, py_script, "--mpi-relaunch", callable_and_args],
env=newenv)
@contextlib.contextmanager
def pytest_raises_on_rank(
my_rank: int, fail_rank: int,
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
) -> Generator[contextlib.AbstractContextManager, None, None]:
"""
Like :func:`pytest.raises`, but only expect an exception on rank *fail_rank*.
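A usage sketch (``comm`` is a hypothetical :class:`mpi4py.MPI.Comm`)::

    with pytest_raises_on_rank(comm.rank, 0, ValueError):
        do_something()   # expected to raise ValueError on rank 0 only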
"""
from contextlib import nullcontext
import pytest
if my_rank == fail_rank:
cm: contextlib.AbstractContextManager = pytest.raises(expected_exception)
else:
cm = nullcontext()
with cm as exc:
yield exc
"""See pytools.prefork for this module's reason for being."""
from __future__ import absolute_import
from __future__ import annotations
import mpi4py.rc # pylint:disable=import-error
mpi4py.rc.initialize = False
from mpi4py.MPI import * # noqa pylint:disable=wildcard-import,wrong-import-position
import pytools.prefork # pylint:disable=wrong-import-position
pytools.prefork.enable_prefork()
if Is_initialized(): # noqa pylint:disable=undefined-variable
# pylint: disable-next=undefined-variable
if Is_initialized(): # type: ignore[name-defined,unused-ignore] # noqa
raise RuntimeError("MPI already initialized before MPI wrapper import")
......
from __future__ import absolute_import, division
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2020 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from functools import partial, update_wrapper
from warnings import warn
import numpy as np
from pytools import my_decorator as decorator, MovedFunctionDeprecationWrapper
__doc__ = """
Handling :mod:`numpy` Object Arrays
===================================
.. autofunction:: oarray_real
.. autofunction:: oarray_imag
.. autofunction:: oarray_real_copy
.. autofunction:: oarray_imag_copy
Creation
--------
.. autofunction:: join_fields
.. autofunction:: make_obj_array
.. autofunction:: flat_obj_array
Mapping
-------
.. autofunction:: with_object_array_or_scalar
.. autofunction:: with_object_array_or_scalar_n_args
.. autofunction:: obj_array_vectorize
.. autofunction:: obj_array_vectorize_n_args
Numpy workarounds
-----------------
These functions work around a `long-standing, annoying numpy issue
<https://github.com/numpy/numpy/issues/1740>`__.
.. autofunction:: obj_array_real
.. autofunction:: obj_array_imag
.. autofunction:: obj_array_real_copy
.. autofunction:: obj_array_imag_copy
"""
def gen_len(expr):
if is_obj_array(expr):
return len(expr)
else:
return 1
def make_obj_array(res_list):
"""Create a one-dimensional object array from *res_list*.
This differs from ``numpy.array(res_list, dtype=object)``
by whether it tries to determine its shape by descending
into nested array-like objects. Consider the following example:
.. doctest::
>>> import numpy as np
>>> a = np.array([np.arange(5), np.arange(5)], dtype=object)
>>> a
array([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]], dtype=object)
>>> a.shape
(2, 5)
>>> # meanwhile:
>>> from pytools.obj_array import make_obj_array
>>> b = make_obj_array([np.arange(5), np.arange(5)])
>>> b
array([array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4])], dtype=object)
>>> b.shape
(2,)
In some settings (such as when the sub-arrays are large and/or
live on a GPU), the recursive behavior of :func:`numpy.array`
can be undesirable.
"""
result = np.empty((len(res_list),), dtype=object)
# 'result[:] = res_list' may look tempting, however:
# https://github.com/numpy/numpy/issues/16564
    for idx in range(len(res_list)):
        result[idx] = res_list[idx]
    return result
def gen_slice(expr, slice_):
result = expr[slice_]
if len(result) == 1:
return result[0]
else:
return result
def obj_array_to_hashable(f):
if isinstance(f, np.ndarray) and f.dtype.char == "O":
return tuple(f)
return f
def flat_obj_array(*args):
"""Return a one-dimensional flattened object array consisting of
elements obtained by 'flattening' *args* as follows:
- The first axis of any non-subclassed object arrays will be flattened
into the result.
- Instances of :class:`list` will be flattened into the result.
- Any other type will appear in the list as-is.
"""
res_list = []
for arg in args:
if isinstance(arg, list):
res_list.extend(arg)
# Only flatten genuine, non-subclassed object arrays.
elif type(arg) is np.ndarray:
res_list.extend(arg.flat)
else:
res_list.append(arg)
return make_obj_array(res_list)
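# Illustrative sketch: flat_obj_array flattens plain object arrays and lists
# one level deep and keeps scalars as-is, so the result below has five entries.
def _example_flat_obj_array():
    two_vecs = make_obj_array([np.arange(3), np.arange(3)])
    flat = flat_obj_array(two_vecs, [7, 8], 9)
    assert flat.shape == (5,)
    return flat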
def obj_array_vectorize(f, ary):
"""Apply the function *f* to all entries of the object array *ary*.
Return an object array of the same shape consisting of the return
values.
If *ary* is not an object array, return ``f(ary)``.
.. note ::
This function exists because :class:`numpy.vectorize` suffers from the same
issue described under :func:`make_obj_array`.
"""
if isinstance(ary, np.ndarray) and ary.dtype.char == "O":
result = np.empty_like(ary)
for i in np.ndindex(ary.shape):
result[i] = f(ary[i])
return result
return f(ary)
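# Illustrative sketch: obj_array_vectorize maps a function over the components
# of an object array, preserving its shape; a non-object-array argument is
# passed to the function directly.
def _example_obj_array_vectorize():
    vec = make_obj_array([np.arange(3), np.arange(4)])
    doubled = obj_array_vectorize(lambda component: 2*component, vec)
    assert doubled.shape == (2,)
    assert doubled[1][3] == 6
    return doubled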
def is_obj_array(val):
try:
return isinstance(val, np.ndarray) and val.dtype == object
except AttributeError:
return False
def obj_array_vectorized(f):
wrapper = partial(obj_array_vectorize, f)
update_wrapper(wrapper, f)
return wrapper
def to_obj_array(ary):
    ls = log_shape(ary)
    result = np.empty(ls, dtype=object)
    from pytools import indices_in_shape
    for i in indices_in_shape(ls):
        result[i] = ary[i]
    return result
def rec_obj_array_vectorize(f, ary):
    """Apply the function *f* to all entries of the object array *ary*.
    Return an object array of the same shape consisting of the return
    values.
    If the elements of *ary* are further object arrays, recurse
    until non-object-arrays are found and then apply *f* to those
    entries.
    If *ary* is not an object array, return ``f(ary)``.
    .. note ::
    This function exists because :class:`numpy.vectorize` suffers from the same
    issue described under :func:`make_obj_array`.
    """
    if isinstance(ary, np.ndarray) and ary.dtype.char == "O":
        result = np.empty_like(ary)
        for i in np.ndindex(ary.shape):
            result[i] = rec_obj_array_vectorize(f, ary[i])
        return result
    return f(ary)
def is_equal(a, b):
if is_obj_array(a):
return is_obj_array(b) and (a.shape == b.shape) and (a == b).all()
else:
return not is_obj_array(b) and a == b
def rec_obj_array_vectorized(f):
wrapper = partial(rec_obj_array_vectorize, f)
update_wrapper(wrapper, f)
return wrapper
# moderately deprecated
is_field_equal = is_equal
def obj_array_vectorize_n_args(f, *args):
"""Apply the function *f* elementwise to all entries of any
object arrays in *args*. All such object arrays are expected
to have the same shape (but this is not checked).
    Equivalent to an appropriately-looped execution of::
        result[idx] = f(obj_array_arg1[idx], arg2, obj_array_arg3[idx])
    Return an object array of the same shape as the arguments consisting of the
    return values of *f*.
    .. note ::
    This function exists because :class:`numpy.vectorize` suffers from the same
    issue described under :func:`make_obj_array`.
    """
oarray_arg_indices = []
for i, arg in enumerate(args):
if isinstance(arg, np.ndarray) and arg.dtype.char == "O":
oarray_arg_indices.append(i)
if not oarray_arg_indices:
return f(*args)
leading_oa_index = oarray_arg_indices[0]
template_ary = args[leading_oa_index]
result = np.empty_like(template_ary)
new_args = list(args)
for i in np.ndindex(template_ary.shape):
for arg_i in oarray_arg_indices:
new_args[arg_i] = args[arg_i][i]
result[i] = f(*new_args)
return result
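# Illustrative sketch: combine two object arrays elementwise; the plain scalar
# argument is forwarded unchanged to every call of the function.
def _example_obj_array_vectorize_n_args():
    x = make_obj_array([np.arange(3), np.arange(3)])
    y = make_obj_array([np.ones(3), 2*np.ones(3)])
    z = obj_array_vectorize_n_args(lambda a, b, shift: a + b + shift, x, y, 10)
    assert z[1][0] == 12.0
    return z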
def setify_field(f):
if is_obj_array(f):
return set(f)
else:
return set([f])
def obj_array_vectorized_n_args(f):
# Unfortunately, this can't use partial(), as the callable returned by it
# will not be turned into a bound method upon attribute access.
# This may happen here, because the decorator *could* be used
# on methods, since it can "look past" the leading `self` argument.
# Only exactly function objects receive this treatment.
#
# Spec link:
# https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy
# (under "Instance Methods", quote as of Py3.9.4)
# > Also notice that this transformation only happens for user-defined functions;
# > other callable objects (and all non-callable objects) are retrieved
# > without transformation.
    def wrapper(*args):
        return obj_array_vectorize_n_args(f, *args)
    update_wrapper(wrapper, f)
    return wrapper
hashable_field = MovedFunctionDeprecationWrapper(obj_array_to_hashable)
def obj_array_equal(a, b):
    a_is_oa = is_obj_array(a)
    assert a_is_oa == is_obj_array(b)
    if a_is_oa:
        return np.array_equal(a, b)
    else:
        return a == b
field_equal = MovedFunctionDeprecationWrapper(obj_array_equal)
def join_fields(*args):
    res_list = []
    for arg in args:
        if isinstance(arg, list):
            res_list.extend(arg)
        elif isinstance(arg, np.ndarray):
            if log_shape(arg) == ():
                res_list.append(arg)
            else:
                res_list.extend(arg.flat)
        else:
            res_list.append(arg)
    return make_obj_array(res_list)
# {{{ workarounds for https://github.com/numpy/numpy/issues/1740
def obj_array_real(ary):
    return rec_obj_array_vectorize(lambda x: x.real, ary)
def obj_array_imag(ary):
    return rec_obj_array_vectorize(lambda x: x.imag, ary)
def obj_array_real_copy(ary):
    return rec_obj_array_vectorize(lambda x: x.real.copy(), ary)
def obj_array_imag_copy(ary):
    return rec_obj_array_vectorize(lambda x: x.imag.copy(), ary)
# }}}
# {{{ deprecated junk
def is_obj_array(val):
warn("is_obj_array is deprecated and will go away in 2022, "
"just inline the check.", DeprecationWarning, stacklevel=2)
try:
return isinstance(val, np.ndarray) and val.dtype.char == "O"
except AttributeError:
return False
def log_shape(array):
    """Returns the "logical shape" of the array.
    The "logical shape" is the shape that's left when the node-depending
    dimension has been eliminated.
    """
warn("log_shape is deprecated and will go away in 2021, "
"use the actual object array shape.",
DeprecationWarning, stacklevel=2)
try:
        if array.dtype.char == "O":
            return array.shape
        return array.shape[:-1]
except AttributeError:
return ()
def join_fields(*args):
warn("join_fields is deprecated and will go away in 2022, "
"use flat_obj_array", DeprecationWarning, stacklevel=2)
return flat_obj_array(*args)
def is_equal(a, b):
warn("is_equal is deprecated and will go away in 2021, "
"use numpy.array_equal", DeprecationWarning, stacklevel=2)
if is_obj_array(a):
return is_obj_array(b) and (a.shape == b.shape) and (a == b).all()
return not is_obj_array(b) and a == b
is_field_equal = is_equal
def gen_len(expr):
if is_obj_array(expr):
return len(expr)
return 1
def gen_slice(expr, slice_):
warn("gen_slice is deprecated and will go away in 2021",
DeprecationWarning, stacklevel=2)
result = expr[slice_]
if len(result) == 1:
return result[0]
return result
def obj_array_equal(a, b):
warn("obj_array_equal is deprecated and will go away in 2021, "
"use numpy.array_equal", DeprecationWarning, stacklevel=2)
a_is_oa = is_obj_array(a)
assert a_is_oa == is_obj_array(b)
if a_is_oa:
return np.array_equal(a, b)
return a == b
def to_obj_array(ary):
warn("to_obj_array is deprecated and will go away in 2021, "
"use make_obj_array", DeprecationWarning,
stacklevel=2)
ls = log_shape(ary)
result = np.empty(ls, dtype=object)
for i in np.ndindex(ls):
result[i] = ary[i]
return result
def setify_field(f):
warn("setify_field is deprecated and will go away in 2021",
DeprecationWarning, stacklevel=2)
if is_obj_array(f):
return set(f)
return {f}
def cast_field(field, dtype):
warn("cast_field is deprecated and will go away in 2021",
DeprecationWarning, stacklevel=2)
return with_object_array_or_scalar(
lambda f: f.astype(dtype), field)
def with_object_array_or_scalar(f, field, obj_array_only=False):
warn("with_object_array_or_scalar is deprecated and will go away in 2022, "
"use obj_array_vectorize", DeprecationWarning, stacklevel=2)
if obj_array_only:
if is_obj_array(field):
ls = field.shape
......@@ -147,19 +376,24 @@ def with_object_array_or_scalar(f, field, obj_array_only=False):
else:
ls = log_shape(field)
    if ls != ():
        result = np.zeros(ls, dtype=object)
        for i in np.ndindex(ls):
            result[i] = f(field[i])
        return result
    return f(field)
def as_oarray_func(f):
    wrapper = partial(with_object_array_or_scalar, f)
    update_wrapper(wrapper, f)
    return wrapper
def with_object_array_or_scalar_n_args(f, *args):
warn("with_object_array_or_scalar_n_args is deprecated and "
"will go away in 2022, "
"use obj_array_vectorize_n_args", DeprecationWarning, stacklevel=2)
oarray_arg_indices = []
for i, arg in enumerate(args):
if is_obj_array(arg):
......@@ -172,39 +406,47 @@ def with_object_array_or_scalar_n_args(f, *args):
ls = log_shape(args[leading_oa_index])
    if ls != ():
        result = np.zeros(ls, dtype=object)
        new_args = list(args)
        for i in np.ndindex(ls):
            for arg_i in oarray_arg_indices:
                new_args[arg_i] = args[arg_i][i]
            result[i] = f(*new_args)
        return result
    return f(*args)
def as_oarray_func_n_args(f):
    wrapper = partial(with_object_array_or_scalar_n_args, f)
    update_wrapper(wrapper, f)
    return wrapper
def oarray_real(ary):
    warn("oarray_real is deprecated and will go away in 2022, "
            "use obj_array_real", DeprecationWarning, stacklevel=2)
    return obj_array_real(ary)
def oarray_imag(ary):
    warn("oarray_imag is deprecated and will go away in 2022, "
            "use obj_array_imag", DeprecationWarning, stacklevel=2)
    return obj_array_imag(ary)
def oarray_real_copy(ary):
    warn("oarray_real_copy is deprecated and will go away in 2022, "
            "use obj_array_real_copy", DeprecationWarning, stacklevel=2)
    return obj_array_real_copy(ary)
def oarray_imag_copy(ary):
    warn("oarray_imag_copy is deprecated and will go away in 2022, "
            "use obj_array_imag_copy", DeprecationWarning, stacklevel=2)
    return obj_array_imag_copy(ary)
# }}}
# vim: foldmethod=marker
from __future__ import annotations
"""Generic persistent, concurrent dictionary-like facility."""
__copyright__ = """
Copyright (C) 2011,2014 Andreas Kloeckner
......@@ -27,23 +29,51 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import logging
import os
import pickle
import sqlite3
import sys
from collections.abc import Callable, Iterator, Mapping
from dataclasses import fields as dc_fields, is_dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from warnings import warn
class RecommendedHashNotFoundWarning(UserWarning):
pass
import collections.abc as abc
try:
    from siphash24 import siphash13 as _default_hash
except ImportError:
    warn("Unable to import recommended hash 'siphash24.siphash13', "
         "falling back to 'hashlib.sha256'. "
         "Run 'python3 -m pip install siphash24' to install "
         "the recommended hash.",
         RecommendedHashNotFoundWarning, stacklevel=1)
    from hashlib import sha256 as _default_hash
import errno
import shutil
if TYPE_CHECKING:
from _typeshed import ReadableBuffer
from typing_extensions import Self
try:
import attrs
except ModuleNotFoundError:
_HAS_ATTRS = False
else:
_HAS_ATTRS = True
import six
logger = logging.getLogger(__name__)
# NOTE: not always available so they get hardcoded here
SQLITE_BUSY = getattr(sqlite3, "SQLITE_BUSY", 5)
SQLITE_CONSTRAINT_PRIMARYKEY = getattr(sqlite3, "SQLITE_CONSTRAINT_PRIMARYKEY", 1555)
__doc__ = """
Persistent Hashing and Persistent Dictionaries
==============================================
......@@ -54,373 +84,374 @@ valid across interpreter invocations, unlike Python's built-in hashes.
This module also provides a disk-backed dictionary that uses persistent hashing.
.. autoexception:: NoSuchEntryError
.. autoexception:: NoSuchEntryCollisionError
.. autoexception:: ReadOnlyEntryError
.. autoexception:: CollisionWarning
.. autoclass:: Hash
.. autoclass:: KeyBuilder
.. autoclass:: PersistentDict
.. autoclass:: WriteOncePersistentDict
Internal stuff that is only here because the documentation tool wants it
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. class:: K
    A type variable for the key type of a :class:`PersistentDict`.
.. class:: V
    A type variable for the value type of a :class:`PersistentDict`.
"""
try:
    import hashlib
    new_hash = hashlib.sha256
except ImportError:
    # for Python << 2.5
    import sha
    new_hash = sha.new
def _make_dir_recursively(dir_):
    try:
        os.makedirs(dir_)
    except OSError as e:
        from errno import EEXIST
        if e.errno != EEXIST:
            raise
def update_checksum(checksum, obj):
    if isinstance(obj, six.text_type):
        checksum.update(obj.encode("utf8"))
    else:
        checksum.update(obj)
# {{{ cleanup managers
class CleanupBase(object):
    pass
class CleanupManager(CleanupBase):
    def __init__(self):
        self.cleanups = []
def register(self, c):
self.cleanups.insert(0, c)
    def clean_up(self):
        for c in self.cleanups:
            c.clean_up()
    def error_clean_up(self):
        for c in self.cleanups:
            c.error_clean_up()
class LockManager(CleanupBase):
    def __init__(self, cleanup_m, lock_file, stacklevel=0):
        self.lock_file = lock_file
        attempts = 0
        while True:
            try:
                self.fd = os.open(self.lock_file,
                        os.O_CREAT | os.O_WRONLY | os.O_EXCL)
                break
            except OSError:
                pass
            from time import sleep
            sleep(1)
            attempts += 1
            if attempts > 10:
                from warnings import warn
                warn("could not obtain lock--delete '%s' if necessary"
                        % self.lock_file,
                        stacklevel=1 + stacklevel)
            if attempts > 3 * 60:
                raise RuntimeError("waited more than three minutes "
                        "on the lock file '%s'"
                        "--something is wrong" % self.lock_file)
        cleanup_m.register(self)
    def clean_up(self):
        os.close(self.fd)
        os.unlink(self.lock_file)
    def error_clean_up(self):
        pass
class ItemDirManager(CleanupBase):
    def __init__(self, cleanup_m, path, delete_on_error):
        from os.path import isdir
        self.existed = isdir(path)
        self.path = path
        self.delete_on_error = delete_on_error
        cleanup_m.register(self)
    def reset(self):
        try:
            shutil.rmtree(self.path)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise
    def mkdir(self):
        from os import mkdir
        try:
            mkdir(self.path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
    def clean_up(self):
        pass
    def error_clean_up(self):
        if self.delete_on_error:
            self.reset()
# }}}
# {{{ key generation
class Hash(Protocol):
    """A protocol for the hashes from :mod:`hashlib`.
    .. automethod:: update
    .. automethod:: digest
    .. automethod:: hexdigest
    .. automethod:: copy
    """
    def update(self, data: ReadableBuffer) -> None:
        ...
    def digest(self) -> bytes:
        ...
    def hexdigest(self) -> str:
        ...
    def copy(self) -> Self:
        ...
class KeyBuilder:
    """A (stateless) object that computes persistent hashes of objects fed to it.
    Subclassing this class permits customizing the computation of hash keys.
    This class follows the same general rules as Python's built-in hashing:
    - Only immutable objects can be hashed.
    - If two objects compare equal, they must hash to the same value.
    - Objects with the same hash may or may not compare equal.
    In addition, hashes computed with :class:`KeyBuilder` have the following
    properties:
    - The hash is persistent across interpreter invocations.
    - The hash is the same across different Python versions and platforms.
    - The hash is invariant with respect to :envvar:`PYTHONHASHSEED`.
    - Hashes are computed using functionality from :mod:`hashlib`.
    Key builders of this type are used by :class:`PersistentDict`, but
    other uses are entirely allowable.
    .. automethod:: __call__
    .. automethod:: rec
    .. staticmethod:: new_hash()
        Return a new hash instance following the protocol of the ones
        from :mod:`hashlib`. This will permit switching to different
        hash algorithms in the future. Subclasses are expected to use
        this to create new hashes. Not doing so is deprecated and
        may stop working as early as 2022.
        .. versionadded:: 2021.2
    .. note::
        Some key-building uses system byte order, so the built keys may not match
        across different systems. It would be desirable to fix this, but this is
        not yet done.
    """
    # this exists so that we can (conceivably) switch algorithms at some point
    # down the road
    new_hash: Callable[..., Hash] = _default_hash
    def rec(self, key_hash: Hash, key: Any) -> Hash:
        """
        :arg key_hash: the hash object to be updated with the hash of *key*.
        :arg key: the (immutable) Python object to be hashed.
        :returns: the updated *key_hash*
        .. versionchanged:: 2021.2
            Now returns the updated *key_hash*.
        """
        digest = getattr(key, "_pytools_persistent_hash_digest", None)
        if digest is None and not isinstance(key, type):
            try:
                method = key.update_persistent_hash
            except AttributeError:
                pass
            else:
                inner_key_hash = self.new_hash()
                method(inner_key_hash, self)
                digest = inner_key_hash.digest()
        if digest is None:
            tp = type(key)
            tname = tp.__name__
            method = None
            try:
                method = getattr(self, "update_for_"+tname)
            except AttributeError:
                pass
            if "numpy" in sys.modules:
                import numpy as np
                # Hashing numpy dtypes
                if (
                        # Handling numpy >= 1.20, for which
                        # type(np.dtype("float32")) -> "dtype[float32]"
                        tname.startswith("dtype[")
                        # Handling numpy >= 1.25, for which
                        # type(np.dtype("float32")) -> "Float32DType"
                        or tname.endswith("DType")
                        ):
                    if isinstance(key, np.dtype):
                        method = self.update_for_specific_dtype
                # Hashing numpy scalars
                elif isinstance(key, np.number | np.bool_):
                    # Non-numpy scalars are handled above in the try block.
                    method = self.update_for_numpy_scalar
            if method is None:
                if issubclass(tp, Enum):
                    method = self.update_for_enum
                elif is_dataclass(tp):
                    method = self.update_for_dataclass
                elif _HAS_ATTRS and attrs.has(tp):
                    method = self.update_for_attrs
            if method is not None:
                inner_key_hash = self.new_hash()
                method(inner_key_hash, key)
                digest = inner_key_hash.digest()
        if digest is None:
            raise TypeError(
                f"unsupported type for persistent hash keying: {type(key)}")
        if not isinstance(key, type):
            try:
                object.__setattr__(key, "_pytools_persistent_hash_digest", digest)
            except AttributeError:
                pass
            except TypeError:
                pass
        key_hash.update(digest)
        return key_hash
    def __call__(self, key: Any) -> str:
        """Return the hash of *key*."""
        key_hash = self.new_hash()
        self.rec(key_hash, key)
        return key_hash.hexdigest()
# {{{ updaters
@staticmethod
def update_for_int(key_hash, key):
key_hash.update(str(key).encode("utf8"))
update_for_long = update_for_int
update_for_bool = update_for_int
# NOTE: None of these should be static or classmethods. While Python itself is
# perfectly OK with overriding those with 'normal' methods, type checkers
# understandably don't like it.
@staticmethod
def update_for_float(key_hash, key):
key_hash.update(repr(key).encode("utf8"))
def update_for_type(self, key_hash: Hash, key: type) -> None:
key_hash.update(
f"{key.__module__}.{key.__qualname__}.{key.__name__}".encode())
if sys.version_info >= (3,):
@staticmethod
def update_for_str(key_hash, key):
key_hash.update(key.encode('utf8'))
update_for_ABCMeta = update_for_type
@staticmethod
def update_for_bytes(key_hash, key):
key_hash.update(key)
else:
@staticmethod
def update_for_str(key_hash, key):
key_hash.update(key)
@staticmethod
def update_for_unicode(key_hash, key):
key_hash.update(key.encode('utf8'))
def update_for_int(self, key_hash: Hash, key: int) -> None:
sz = 8
while True:
try:
# Must match system byte order so that numpy and this
# generate the same string of bytes.
# https://github.com/inducer/pytools/issues/259
key_hash.update(key.to_bytes(sz, byteorder=sys.byteorder, signed=True))
return
except OverflowError:
sz *= 2
def update_for_enum(self, key_hash: Hash, key: Enum) -> None:
self.update_for_str(key_hash, str(key))
def update_for_bool(self, key_hash: Hash, key: bool) -> None:
key_hash.update(str(key).encode("utf8"))
def update_for_tuple(self, key_hash, key):
for obj_i in key:
self.rec(key_hash, obj_i)
def update_for_float(self, key_hash: Hash, key: float) -> None:
key_hash.update(key.hex().encode("utf8"))
def update_for_frozenset(self, key_hash, key):
for set_key in sorted(key):
self.rec(key_hash, set_key)
def update_for_complex(self, key_hash: Hash, key: float) -> None:
key_hash.update(repr(key).encode("utf-8"))
@staticmethod
def update_for_NoneType(key_hash, key): # noqa
del key
key_hash.update("<None>".encode('utf8'))
def update_for_str(self, key_hash: Hash, key: str) -> None:
key_hash.update(key.encode("utf8"))
@staticmethod
def update_for_dtype(key_hash, key):
key_hash.update(key.str.encode('utf8'))
def update_for_bytes(self, key_hash: Hash, key: bytes) -> None:
key_hash.update(key)
# }}}
def update_for_tuple(self, key_hash: Hash, key: tuple[Any, ...]) -> None:
for obj_i in key:
self.rec(key_hash, obj_i)
# }}}
def update_for_frozenset(self, key_hash: Hash, key: frozenset[Any]) -> None:
from pytools import unordered_hash
unordered_hash(
key_hash,
(self.rec(self.new_hash(), key_i).digest() for key_i in key),
hash_constructor=self.new_hash)
# {{{ lru cache
update_for_FrozenOrderedSet = update_for_frozenset
class _LinkedList(object):
"""The list operates on nodes of the form [value, leftptr, rightpr]. To create a
node of this form you can use `LinkedList.new_node().`
def update_for_NoneType(self, key_hash: Hash, key: None) -> None:
del key
key_hash.update(b"<None>")
def update_for_dtype(self, key_hash: Hash, key: Any) -> None:
key_hash.update(key.str.encode("utf8"))
# Handling numpy >= 1.20, for which
# type(np.dtype("float32")) -> "dtype[float32]"
# Introducing this method allows subclasses to specially handle all those
# dtypes.
def update_for_specific_dtype(self, key_hash: Hash, key: Any) -> None:
key_hash.update(key.str.encode("utf8"))
def update_for_numpy_scalar(self, key_hash: Hash, key: Any) -> None:
import numpy as np
if hasattr(np, "complex256") and key.dtype == np.dtype("complex256"):
key_hash.update(repr(complex(key)).encode("utf8"))
elif hasattr(np, "float128") and key.dtype == np.dtype("float128"):
key_hash.update(repr(float(key)).encode("utf8"))
else:
key_hash.update(np.array(key).tobytes())
Supports inserting at the left and deleting from an arbitrary location.
"""
def __init__(self):
self.count = 0
self.head = None
self.end = None
def update_for_dataclass(self, key_hash: Hash, key: Any) -> None:
self.rec(key_hash, f"{type(key).__qualname__}.{type(key).__name__}")
@staticmethod
def new_node(element):
return [element, None, None]
for fld in dc_fields(key):
self.rec(key_hash, fld.name)
self.rec(key_hash, getattr(key, fld.name, None))
def __len__(self):
return self.count
def update_for_attrs(self, key_hash: Hash, key: Any) -> None:
self.rec(key_hash, f"{type(key).__qualname__}.{type(key).__name__}")
def appendleft_node(self, node):
self.count += 1
for fld in attrs.fields(key.__class__):
self.rec(key_hash, fld.name)
self.rec(key_hash, getattr(key, fld.name, None))
if self.head is None:
self.head = self.end = node
return
def update_for_frozendict(self, key_hash: Hash, key: Mapping[Any, Any]) -> None:
from pytools import unordered_hash
self.head[1] = node
node[2] = self.head
unordered_hash(
key_hash,
(self.rec(self.new_hash(), (k, v)).digest() for k, v in key.items()),
hash_constructor=self.new_hash)
self.head = node
update_for_immutabledict = update_for_frozendict
update_for_constantdict = update_for_frozendict
update_for_PMap = update_for_frozendict
update_for_Map = update_for_frozendict
def pop_node(self):
end = self.end
self.remove_node(end)
return end
# {{{ date, time, datetime, timezone
def remove_node(self, node):
self.count -= 1
def update_for_date(self, key_hash: Hash, key: Any) -> None:
# 'date' has no timezone information; it is always naive
self.rec(key_hash, key.isoformat())
if self.head is self.end:
assert node is self.head
self.head = self.end = None
return
def update_for_time(self, key_hash: Hash, key: Any) -> None:
# 'time' should differentiate between naive and aware
import datetime
left = node[1]
right = node[2]
# Convert to datetime object
self.rec(key_hash, datetime.datetime.combine(datetime.date.min, key))
self.rec(key_hash, "<time>")
if left is None:
self.head = right
else:
left[2] = right
def update_for_datetime(self, key_hash: Hash, key: Any) -> None:
# 'datetime' should differentiate between naive and aware
if right is None:
self.end = left
# https://docs.python.org/3.11/library/datetime.html#determining-if-an-object-is-aware-or-naive
if key.tzinfo is not None and key.tzinfo.utcoffset(key) is not None:
self.rec(key_hash, key.timestamp())
self.rec(key_hash, "<aware>")
else:
right[1] = left
from datetime import timezone
self.rec(key_hash, key.replace(tzinfo=timezone.utc).timestamp())
self.rec(key_hash, "<naive>")
node[1] = node[2] = None
def update_for_timezone(self, key_hash: Hash, key: Any) -> None:
self.rec(key_hash, repr(key))
# }}}
class _LRUCache(abc.MutableMapping):
"""A mapping that keeps at most *maxsize* items with an LRU replacement policy.
"""
def __init__(self, maxsize):
self.lru_order = _LinkedList()
self.maxsize = maxsize
self.cache = {}
def __delitem__(self, item):
node = self.cache[item]
self.lru_order.remove_node(node)
del self.cache[item]
def __getitem__(self, item):
node = self.cache[item]
self.lru_order.remove_node(node)
self.lru_order.appendleft_node(node)
# A linked list node contains a tuple of the form (item, value).
return node[0][1]
def update_for_function(self, key_hash: Hash, key: Any) -> None:
self.rec(key_hash, key.__module__ + key.__qualname__)
def __contains__(self, item):
return item in self.cache
if key.__closure__:
self.rec(key_hash, tuple(c.cell_contents for c in key.__closure__))
def __iter__(self):
return iter(self.cache)
# }}}
def __len__(self):
return len(self.cache)
# }}}
def clear(self):
self.cache.clear()
self.lru_order = _LinkedList()
def __setitem__(self, item, value):
if self.maxsize < 1:
return
        try:
            node = self.cache[item]
            self.lru_order.remove_node(node)
        except KeyError:
            if len(self.lru_order) >= self.maxsize:
                # Make room for new elements.
                end_node = self.lru_order.pop_node()
                del self.cache[end_node[0][0]]
        node = self.lru_order.new_node((item, value))
        self.cache[item] = node
        self.lru_order.appendleft_node(node)
        assert len(self.cache) == len(self.lru_order), \
                (len(self.cache), len(self.lru_order))
        assert len(self.lru_order) <= self.maxsize
# }}}
# {{{ top-level
class NoSuchEntryError(KeyError):
    """Raised when an entry is not found in a :class:`PersistentDict`."""
class NoSuchEntryCollisionError(NoSuchEntryError):
    """Raised when an entry is not found in a :class:`PersistentDict`, but it
    contains an entry with the same hash key (hash collision)."""
class ReadOnlyEntryError(KeyError):
    """Raised when an attempt is made to overwrite an entry in a
    :class:`WriteOncePersistentDict`."""
class CollisionWarning(UserWarning):
    """Warning raised when a collision is detected in a :class:`PersistentDict`."""
def __getattr__(name: str) -> Any:
    if name in ("NoSuchEntryInvalidKeyError",
            "NoSuchEntryInvalidContentsError"):
        warn(f"pytools.persistent_dict.{name} has been removed.", stacklevel=2)
        return NoSuchEntryError
    raise AttributeError(name)
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
class _PersistentDictBase(Mapping[K, V]):
def __init__(self,
identifier: str,
key_builder: KeyBuilder | None = None,
container_dir: str | None = None,
enable_wal: bool = False,
safe_sync: bool | None = None) -> None:
self.identifier = identifier
self.conn = None
if key_builder is None:
key_builder = KeyBuilder()
......@@ -429,256 +460,314 @@ class _PersistentDictBase(object):
from os.path import join
if container_dir is None:
import appdirs
container_dir = join(
appdirs.user_cache_dir("pytools", "pytools"),
"pdict-v2-%s-py%s" % (
identifier,
".".join(str(i) for i in sys.version_info),))
self.container_dir = container_dir
import platformdirs
self._make_container_dir()
@staticmethod
def _warn(msg, category=UserWarning, stacklevel=0):
from warnings import warn
warn(msg, category, stacklevel=1 + stacklevel)
def store_if_not_present(self, key, value, _stacklevel=0):
self.store(key, value, _skip_if_present=True, _stacklevel=1 + _stacklevel)
def store(self, key, value, _skip_if_present=False, _stacklevel=0):
raise NotImplementedError()
def fetch(self, key, _stacklevel=0):
raise NotImplementedError()
@staticmethod
def _read(path):
from six.moves.cPickle import load
with open(path, "rb") as inf:
return load(inf)
if sys.platform == "darwin" and os.getenv("XDG_CACHE_HOME") is not None:
# platformdirs does not handle XDG_CACHE_HOME on macOS
# https://github.com/platformdirs/platformdirs/issues/269
container_dir = join(os.getenv("XDG_CACHE_HOME"), "pytools")
else:
container_dir = platformdirs.user_cache_dir("pytools", "pytools")
@staticmethod
def _write(path, value):
from six.moves.cPickle import dump, HIGHEST_PROTOCOL
with open(path, "wb") as outf:
dump(value, outf, protocol=HIGHEST_PROTOCOL)
self.filename = join(container_dir, f"pdict-v5-{identifier}"
+ ".".join(str(i) for i in sys.version_info)
+ ".sqlite")
def _item_dir(self, hexdigest_key):
from os.path import join
return join(self.container_dir, hexdigest_key)
def _key_file(self, hexdigest_key):
from os.path import join
return join(self._item_dir(hexdigest_key), "key")
self.container_dir = container_dir
self._make_container_dir()
def _contents_file(self, hexdigest_key):
from os.path import join
return join(self._item_dir(hexdigest_key), "contents")
from threading import Lock
self.mutex = Lock()
# * isolation_level=None: enable autocommit mode
# https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
# * check_same_thread=False: thread-level concurrency is handled by the
# mutex above
self.conn = sqlite3.connect(self.filename,
isolation_level=None,
check_same_thread=False)
self._exec_sql(
"CREATE TABLE IF NOT EXISTS dict "
"(keyhash TEXT NOT NULL PRIMARY KEY, key_value TEXT NOT NULL)"
)
# https://www.sqlite.org/wal.html
if enable_wal:
self._exec_sql("PRAGMA journal_mode = 'WAL'")
# Note: the following configuration values were taken mostly from litedict:
# https://github.com/litements/litedict/blob/377603fa597453ffd9997186a493ed4fd23e5399/litedict.py#L67-L70
# Use in-memory temp store
# https://www.sqlite.org/pragma.html#pragma_temp_store
self._exec_sql("PRAGMA temp_store = 'MEMORY'")
# fsync() can be extremely slow on some systems.
# See https://github.com/inducer/pytools/issues/227 for context.
# https://www.sqlite.org/pragma.html#pragma_synchronous
if safe_sync is None or safe_sync:
if safe_sync is None:
warn(f"pytools.persistent_dict '{identifier}': "
"enabling safe_sync as default. "
"This provides strong protection against data loss, "
"but can be unnecessarily expensive for use cases such as "
"caches."
"Pass 'safe_sync=False' if occasional data loss is tolerable. "
"Pass 'safe_sync=True' to suppress this warning.",
stacklevel=3)
self._exec_sql("PRAGMA synchronous = 'NORMAL'")
else:
self._exec_sql("PRAGMA synchronous = 'OFF'")
def _lock_file(self, hexdigest_key):
from os.path import join
return join(self.container_dir, str(hexdigest_key) + ".lock")
# 64 MByte of cache
# https://www.sqlite.org/pragma.html#pragma_cache_size
self._exec_sql("PRAGMA cache_size = -64000")
def _make_container_dir(self):
_make_dir_recursively(self.container_dir)
def __del__(self) -> None:
with self.mutex:
if self.conn:
self.conn.close()
    def _collision_check(self, key: K, stored_key: K) -> None:
        if stored_key != key:
            # Key collision, oh well.
            warn(f"{self.identifier}: key collision in cache at "
                 f"'{self.container_dir}' -- these are sufficiently unlikely "
                 "that they're often indicative of a broken hash key "
                 "implementation (that is not considering some elements "
                 "relevant for equality comparison)",
                 CollisionWarning,
                 stacklevel=3
                 )
            # This is here so we can step through equality comparison to
            # see what is actually non-equal.
            stored_key == key  # noqa: B015
            raise NoSuchEntryCollisionError(key)
    def _exec_sql(self, *args: Any) -> sqlite3.Cursor:
        def execute() -> sqlite3.Cursor:
            assert self.conn is not None
            return self.conn.execute(*args)
        cursor = self._exec_sql_fn(execute)
        if not isinstance(cursor, sqlite3.Cursor):
            raise RuntimeError("Failed to execute SQL statement")
        return cursor
def _exec_sql_fn(self, fn: Callable[[], T]) -> T | None:
n = 0
        with self.mutex:
while True:
n += 1
try:
return fn()
except sqlite3.OperationalError as e:
# If the database is busy, retry
if (hasattr(e, "sqlite_errorcode")
and e.sqlite_errorcode != SQLITE_BUSY):
raise
if n % 20 == 0:
warn(f"PersistentDict: database '{self.filename}' busy, {n} "
"retries", stacklevel=3)
else:
break
def store_if_not_present(self, key: K, value: V) -> None:
"""Store (*key*, *value*) if *key* is not already present."""
self.store(key, value, _skip_if_present=True)
def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
"""Store (*key*, *value*) in the dictionary."""
raise NotImplementedError
def fetch(self, key: K) -> V:
"""Return the value associated with *key* in the dictionary."""
raise NotImplementedError
def _make_container_dir(self) -> None:
"""Create the container directory to store the dictionary."""
os.makedirs(self.container_dir, exist_ok=True)
def __getitem__(self, key: K) -> V:
"""Return the value associated with *key* in the dictionary."""
return self.fetch(key)
def __setitem__(self, key: K, value: V) -> None:
"""Store (*key*, *value*) in the dictionary."""
self.store(key, value)
def __len__(self) -> int:
"""Return the number of entries in the dictionary."""
result, = next(self._exec_sql("SELECT COUNT(*) FROM dict"))
assert isinstance(result, int)
return result
def __iter__(self) -> Iterator[K]:
"""Return an iterator over the keys in the dictionary."""
return self.keys()
def keys(self) -> Iterator[K]: # type: ignore[override]
"""Return an iterator over the keys in the dictionary."""
for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
yield pickle.loads(row[0])[0]
def values(self) -> Iterator[V]: # type: ignore[override]
"""Return an iterator over the values in the dictionary."""
for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
yield pickle.loads(row[0])[1]
def items(self) -> Iterator[tuple[K, V]]: # type: ignore[override]
"""Return an iterator over the items in the dictionary."""
for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
yield pickle.loads(row[0])
def nbytes(self) -> int:
"""Return the size of the dictionary in bytes."""
result, = next(self._exec_sql("SELECT page_size * page_count FROM "
"pragma_page_size(), pragma_page_count()"))
assert isinstance(result, int)
return result
def __repr__(self) -> str:
"""Return a string representation of the dictionary."""
return f"{type(self).__name__}({self.filename}, nitems={len(self)})"
def clear(self) -> None:
"""Remove all entries from the dictionary."""
self._exec_sql("DELETE FROM dict")
class WriteOncePersistentDict(_PersistentDictBase[K, V]):
"""A concurrent disk-backed dictionary that disallows overwriting/
deletion (but allows removing all entries).
Compared with :class:`PersistentDict`, this class has faster
    retrieval times because it uses an LRU cache to cache entries in memory.
.. note::
This class intentionally does not store all values with a certain
key, based on the assumption that key conflicts are highly unlikely,
and if they occur, almost always due to a bug in the hash key
generation code (:class:`KeyBuilder`).
.. automethod:: __init__
.. automethod:: __getitem__
.. automethod:: __setitem__
.. automethod:: clear
.. automethod:: clear_in_mem_cache
.. automethod:: store
.. automethod:: store_if_not_present
.. automethod:: fetch
"""
    def __init__(self, identifier: str,
key_builder: KeyBuilder | None = None,
container_dir: str | None = None,
*,
enable_wal: bool = False,
safe_sync: bool | None = None,
in_mem_cache_size: int = 256) -> None:
"""
        :arg identifier: a filename-compatible string identifying this
dictionary
:arg key_builder: a subclass of :class:`KeyBuilder`
:arg container_dir: the directory in which to store this
dictionary. If ``None``, the default cache directory from
:func:`platformdirs.user_cache_dir` is used
:arg enable_wal: enable write-ahead logging (WAL) mode. This mode
is faster than the default rollback journal mode, but it is
not compatible with network filesystems.
:arg in_mem_cache_size: retain an in-memory cache of up to
*in_mem_cache_size* items
*in_mem_cache_size* items (with an LRU replacement policy)
"""
        super().__init__(identifier,
                key_builder=key_builder,
                container_dir=container_dir,
                enable_wal=enable_wal,
                safe_sync=safe_sync)
        from functools import lru_cache
        self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch_uncached)
    def _spin_until_removed(self, lock_file, stacklevel):
        from os.path import exists
        attempts = 0
        while exists(lock_file):
            from time import sleep
            sleep(1)
            attempts += 1
            if attempts > 10:
                self._warn("waiting until unlocked--delete '%s' if necessary"
                        % lock_file, stacklevel=1 + stacklevel)
            if attempts > 3 * 60:
                raise RuntimeError("waited more than three minutes "
                        "on the lock file '%s'"
                        "--something is wrong" % lock_file)
    def clear_in_mem_cache(self) -> None:
        """
        Clear the in-memory cache of this dictionary.
        .. versionadded:: 2023.1.1
        """
        self._fetch.cache_clear()
def store(self, key, value, _skip_if_present=False, _stacklevel=0):
hexdigest_key = self.key_builder(key)
def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
keyhash = self.key_builder(key)
v = pickle.dumps((key, value))
cleanup_m = CleanupManager()
try:
if _skip_if_present:
self._exec_sql("INSERT OR IGNORE INTO dict VALUES (?, ?)",
(keyhash, v))
else:
try:
LockManager(cleanup_m, self._lock_file(hexdigest_key),
1 + _stacklevel)
item_dir_m = ItemDirManager(
cleanup_m, self._item_dir(hexdigest_key),
delete_on_error=False)
if item_dir_m.existed:
if _skip_if_present:
return
raise ReadOnlyEntryError(key)
item_dir_m.mkdir()
key_path = self._key_file(hexdigest_key)
value_path = self._contents_file(hexdigest_key)
self._write(value_path, value)
self._write(key_path, key)
logger.debug("%s: disk cache store [key=%s]",
self.identifier, hexdigest_key)
except Exception:
cleanup_m.error_clean_up()
raise
finally:
cleanup_m.clean_up()
def fetch(self, key, _stacklevel=0):
hexdigest_key = self.key_builder(key)
# {{{ in memory cache
self._exec_sql("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
except sqlite3.IntegrityError as e:
if hasattr(e, "sqlite_errorcode"):
if e.sqlite_errorcode == SQLITE_CONSTRAINT_PRIMARYKEY:
raise ReadOnlyEntryError("WriteOncePersistentDict, "
"tried overwriting key") from e
raise
raise ReadOnlyEntryError("WriteOncePersistentDict, "
"tried overwriting key") from e
def _fetch_uncached(self, keyhash: str) -> tuple[K, V]:
# This method is separate from fetch() to allow for LRU caching
def fetch_inner() -> tuple[Any] | None:
assert self.conn is not None
# This is separate from fetch() so that the mutex covers the
# fetchone() call
c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
(keyhash,))
res = c.fetchone()
assert res is None or isinstance(res, tuple)
return res
row = self._exec_sql_fn(fetch_inner)
if row is None:
raise KeyError
key, value = pickle.loads(row[0])
return key, value
def fetch(self, key: K) -> V:
keyhash = self.key_builder(key)
try:
stored_key, stored_value = self._cache[hexdigest_key]
except KeyError:
pass
stored_key, value = self._fetch(keyhash)
except KeyError as err:
raise NoSuchEntryError(key) from err
else:
logger.debug("%s: in mem cache hit [key=%s]",
self.identifier, hexdigest_key)
self._collision_check(key, stored_key, 1 + _stacklevel)
return stored_value
# }}}
# {{{ check path exists and is unlocked
item_dir = self._item_dir(hexdigest_key)
from os.path import isdir
if not isdir(item_dir):
logger.debug("%s: disk cache miss [key=%s]",
self.identifier, hexdigest_key)
raise NoSuchEntryError(key)
self._collision_check(key, stored_key)
return value
lock_file = self._lock_file(hexdigest_key)
self._spin_until_removed(lock_file, 1 + _stacklevel)
def clear(self) -> None:
super().clear()
self._fetch.cache_clear()
# }}}
key_file = self._key_file(hexdigest_key)
contents_file = self._contents_file(hexdigest_key)
# Note: Unlike PersistentDict, this doesn't autodelete invalid entires,
# because that would lead to a race condition.
# {{{ load key file and do equality check
try:
read_key = self._read(key_file)
except Exception as e:
self._warn("pytools.persistent_dict.WriteOncePersistentDict(%s) "
"encountered an invalid "
"key file for key %s. Remove the directory "
"'%s' if necessary. (caught: %s)"
% (self.identifier, hexdigest_key, item_dir, str(e)),
stacklevel=1 + _stacklevel)
raise NoSuchEntryError(key)
self._collision_check(key, read_key, 1 + _stacklevel)
# }}}
logger.debug("%s: disk cache hit [key=%s]",
self.identifier, hexdigest_key)
# {{{ load contents
try:
read_contents = self._read(contents_file)
except Exception:
self._warn("pytools.persistent_dict.WriteOncePersistentDict(%s) "
"encountered an invalid "
"key file for key %s. Remove the directory "
"'%s' if necessary."
% (self.identifier, hexdigest_key, item_dir),
stacklevel=1 + _stacklevel)
raise NoSuchEntryError(key)
# }}}
self._cache[hexdigest_key] = (key, read_contents)
return read_contents
def clear(self):
_PersistentDictBase.clear(self)
self._cache.clear()
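# Illustrative usage sketch (hypothetical identifier): a WriteOncePersistentDict
# stores each key exactly once; overwriting raises ReadOnlyEntryError, while
# store_if_not_present silently keeps the existing entry.
def _example_write_once_persistent_dict():
    wdict: WriteOncePersistentDict[str, int] = WriteOncePersistentDict(
            "example-write-once", safe_sync=False)
    wdict.store_if_not_present("answer", 42)
    try:
        wdict["answer"] = 43
    except ReadOnlyEntryError:
        pass
    return wdict["answer"]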
class PersistentDict(_PersistentDictBase[K, V]):
"""A concurrent disk-backed dictionary.
.. note::
class PersistentDict(_PersistentDictBase):
"""A concurrent disk-backed dictionary.
This class intentionally does not store all values with a certain
key, based on the assumption that key conflicts are highly unlikely,
and if they occur, almost always due to a bug in the hash key
generation code (:class:`KeyBuilder`).
.. automethod:: __init__
.. automethod:: __getitem__
......@@ -690,155 +779,91 @@ class PersistentDict(_PersistentDictBase):
.. automethod:: fetch
.. automethod:: remove
"""
def __init__(self, identifier, key_builder=None, container_dir=None):
def __init__(self,
identifier: str,
key_builder: KeyBuilder | None = None,
container_dir: str | None = None,
*,
enable_wal: bool = False,
safe_sync: bool | None = None) -> None:
"""
        :arg identifier: a filename-compatible string identifying this
dictionary
:arg key_builder: a subclass of :class:`KeyBuilder`
:arg container_dir: the directory in which to store this
dictionary. If ``None``, the default cache directory from
:func:`platformdirs.user_cache_dir` is used
:arg enable_wal: enable write-ahead logging (WAL) mode. This mode
is faster than the default rollback journal mode, but it is
not compatible with network filesystems.
"""
_PersistentDictBase.__init__(self, identifier, key_builder, container_dir)
super().__init__(identifier,
key_builder=key_builder,
container_dir=container_dir,
enable_wal=enable_wal,
safe_sync=safe_sync)
    def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
        keyhash = self.key_builder(key)
v = pickle.dumps((key, value))
cleanup_m = CleanupManager()
try:
try:
LockManager(cleanup_m, self._lock_file(hexdigest_key),
1 + _stacklevel)
item_dir_m = ItemDirManager(
cleanup_m, self._item_dir(hexdigest_key),
delete_on_error=True)
if item_dir_m.existed:
if _skip_if_present:
return
item_dir_m.reset()
item_dir_m.mkdir()
key_path = self._key_file(hexdigest_key)
value_path = self._contents_file(hexdigest_key)
self._write(value_path, value)
self._write(key_path, key)
logger.debug("%s: cache store [key=%s]",
self.identifier, hexdigest_key)
except Exception:
cleanup_m.error_clean_up()
raise
finally:
cleanup_m.clean_up()
def fetch(self, key, _stacklevel=0):
hexdigest_key = self.key_builder(key)
item_dir = self._item_dir(hexdigest_key)
from os.path import isdir
if not isdir(item_dir):
logger.debug("%s: cache miss [key=%s]",
self.identifier, hexdigest_key)
raise NoSuchEntryError(key)
mode = "IGNORE" if _skip_if_present else "REPLACE"
cleanup_m = CleanupManager()
try:
try:
LockManager(cleanup_m, self._lock_file(hexdigest_key),
1 + _stacklevel)
item_dir_m = ItemDirManager(
cleanup_m, item_dir, delete_on_error=False)
key_path = self._key_file(hexdigest_key)
value_path = self._contents_file(hexdigest_key)
# {{{ load key
try:
read_key = self._read(key_path)
except Exception:
item_dir_m.reset()
self._warn("pytools.persistent_dict.PersistentDict(%s) "
"encountered an invalid "
"key file for key %s. Entry deleted."
% (self.identifier, hexdigest_key),
stacklevel=1 + _stacklevel)
raise NoSuchEntryError(key)
self._collision_check(key, read_key, 1 + _stacklevel)
self._exec_sql(f"INSERT OR {mode} INTO dict VALUES (?, ?)",
(keyhash, v))
# }}}
def fetch(self, key: K) -> V:
keyhash = self.key_builder(key)
logger.debug("%s: cache hit [key=%s]",
self.identifier, hexdigest_key)
def fetch_inner() -> tuple[Any] | None:
assert self.conn is not None
# {{{ load value
# This is separate from fetch() so that the mutex covers the
# fetchone() call
c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
(keyhash,))
res = c.fetchone()
assert res is None or isinstance(res, tuple)
return res
try:
read_contents = self._read(value_path)
except Exception:
item_dir_m.reset()
self._warn("pytools.persistent_dict.PersistentDict(%s) "
"encountered an invalid "
"key file for key %s. Entry deleted."
% (self.identifier, hexdigest_key),
stacklevel=1 + _stacklevel)
raise NoSuchEntryError(key)
row = self._exec_sql_fn(fetch_inner)
return read_contents
# }}}
except Exception:
cleanup_m.error_clean_up()
raise
finally:
cleanup_m.clean_up()
def remove(self, key, _stacklevel=0):
hexdigest_key = self.key_builder(key)
item_dir = self._item_dir(hexdigest_key)
from os.path import isdir
if not isdir(item_dir):
if row is None:
raise NoSuchEntryError(key)
cleanup_m = CleanupManager()
try:
try:
LockManager(cleanup_m, self._lock_file(hexdigest_key),
1 + _stacklevel)
item_dir_m = ItemDirManager(
cleanup_m, item_dir, delete_on_error=False)
key_file = self._key_file(hexdigest_key)
stored_key, value = pickle.loads(row[0])
self._collision_check(key, stored_key)
return cast("V", value)
# {{{ load key
def remove(self, key: K) -> None:
"""Remove the entry associated with *key* from the dictionary."""
keyhash = self.key_builder(key)
try:
read_key = self._read(key_file)
except Exception:
item_dir_m.reset()
self._warn("pytools.persistent_dict.PersistentDict(%s) "
"encountered an invalid "
"key file for key %s. Entry deleted."
% (self.identifier, hexdigest_key),
stacklevel=1 + _stacklevel)
def remove_inner() -> None:
assert self.conn is not None
self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
try:
# This is split into SELECT/DELETE to allow for a collision check
c = self.conn.execute("SELECT key_value FROM dict WHERE "
"keyhash=?", (keyhash,))
row = c.fetchone()
if row is None:
raise NoSuchEntryError(key)
self._collision_check(key, read_key, 1 + _stacklevel)
# }}}
stored_key, _value = pickle.loads(row[0])
self._collision_check(key, stored_key)
item_dir_m.reset()
self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
self.conn.execute("COMMIT")
except Exception as e:
self.conn.execute("ROLLBACK")
raise e
except Exception:
cleanup_m.error_clean_up()
raise
finally:
cleanup_m.clean_up()
self._exec_sql_fn(remove_inner)
def __delitem__(self, key):
self.remove(key, _stacklevel=1)
def __delitem__(self, key: K) -> None:
"""Remove the entry associated with *key* from the dictionary."""
self.remove(key)
# }}}
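# Illustrative usage sketch (hypothetical identifiers and values): PersistentDict
# acts as a disk-backed cache. KeyBuilder understands frozen dataclasses, so the
# key below hashes the same way across interpreter runs.
def _example_persistent_dict():
    from dataclasses import dataclass
    @dataclass(frozen=True)
    class SolverConfig:
        n: int
        tol: float
    pdict: PersistentDict[SolverConfig, float] = PersistentDict(
            "example-solver-cache", safe_sync=False)
    key = SolverConfig(n=32, tol=1e-6)
    try:
        return pdict[key]
    except NoSuchEntryError:
        value = 42.0  # stand-in for an expensive computation
        pdict[key] = value
        return value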
......
"""OpenMPI, once initialized, prohibits forking. This helper module
allows the forking of *one* helper child process before OpenMPI
initialization that can do the forking for the fork-challenged
parent process.
Since none of this is MPI-specific, it got parked in :mod:`pytools`.
.. autoexception:: ExecError
:show-inheritance:
.. autoclass:: Forker
.. autoclass:: DirectForker
.. autoclass:: IndirectForker
.. autofunction:: enable_prefork
.. autofunction:: call
.. autofunction:: call_async
.. autofunction:: call_capture_output
.. autofunction:: wait
.. autofunction:: waitall
"""
from __future__ import annotations
import socket
from abc import ABC, abstractmethod
from subprocess import Popen
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from collections.abc import Sequence
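# Illustrative usage sketch (not part of the module): enable the fork helper
# before MPI gets initialized, then run subprocesses through it as usual.
def _example_prefork_usage():
    from pytools.prefork import call_capture_output, enable_prefork
    enable_prefork()  # must run before mpi4py/MPI initialization
    retcode, stdout, stderr = call_capture_output(["echo", "hello"])
    assert retcode == 0
    return stdout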
class ExecError(OSError):
pass
class DirectForker(object):
def __init__(self):
self.apids = {}
self.count = 0
class Forker(ABC):
@abstractmethod
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass
@abstractmethod
def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
pass
@abstractmethod
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
pass
@abstractmethod
def wait(self, aid: int) -> int:
pass
@abstractmethod
def waitall(self) -> dict[int, int]:
pass
class DirectForker(Forker):
def __init__(self) -> None:
self.apids: dict[int, Popen[bytes]] = {}
self.count: int = 0
    def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
from subprocess import call as spcall
try:
return spcall(cmdline, cwd=cwd)
except OSError as e:
            raise ExecError(
                    "error invoking '{}': {}".format(" ".join(cmdline), e)) from e
    def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
try:
self.count += 1
......@@ -38,12 +83,14 @@ class DirectForker(object):
return self.count
except OSError as e:
            raise ExecError(
                    "error invoking '{}': {}".format(" ".join(cmdline), e)) from e
    def call_capture_output(self,
                            cmdline: Sequence[str],
                            cwd: str | None = None,
                            error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
        from subprocess import PIPE, Popen
try:
popen = Popen(cmdline, cwd=cwd, stdin=PIPE, stdout=PIPE,
......@@ -51,33 +98,34 @@ class DirectForker(object):
stdout_data, stderr_data = popen.communicate()
if error_on_nonzero and popen.returncode:
                raise ExecError("status {} invoking '{}': {}".format(
                    popen.returncode,
                    " ".join(cmdline),
                    stderr_data.decode("utf-8", errors="replace")))
return popen.returncode, stdout_data, stderr_data
except OSError as e:
            raise ExecError(
                    "error invoking '{}': {}".format(" ".join(cmdline), e)) from e
    def wait(self, aid: int) -> int:
proc = self.apids.pop(aid)
retc = proc.wait()
return retc
    def waitall(self) -> dict[int, int]:
rets = {}
        for aid in list(self.apids):
rets[aid] = self.wait(aid)
return rets
def _send_packet(sock, data):
def _send_packet(sock: socket.socket, data: object) -> None:
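    # Wire format: the pickled payload is preceded by its length, packed with
    # struct format "I", so the receiver knows how many bytes to expect.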
from pickle import dumps
from struct import pack
packet = dumps(data)
......@@ -85,15 +133,14 @@ def _send_packet(sock, data):
sock.sendall(packet)
def _recv_packet(sock: socket.socket,
who: str = "Process",
partner: str = "other end") -> tuple[object, ...]:
from struct import calcsize, unpack
size_bytes_size = calcsize("I")
size_bytes = sock.recv(size_bytes_size)
if len(size_bytes) < size_bytes_size:
from warnings import warn
warn("%s exiting upon apparent death of %s" % (who, partner))
raise SystemExit
size, = unpack("I", size_bytes)
......@@ -102,11 +149,15 @@ def _recv_packet(sock, who="Process", partner="other end"):
while len(packet) < size:
        # Only ask for the bytes still missing, so we never read into the
        # start of the next packet.
        packet += sock.recv(size - len(packet))
from pickle import loads
result = loads(packet)
assert isinstance(result, tuple)
return result
def _fork_server(sock: socket.socket) -> None:
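    # Runs inside the forked helper process: receive (func_name, args, kwargs)
    # requests from the parent, dispatch them through the ``funcs`` table, and
    # reply with ("ok", result) or ("exception", e); a "quit" request ends the
    # loop.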
# Ignore keyboard interrupts, we'll get notified by the parent.
import signal
signal.signal(signal.SIGINT, signal.SIG_IGN)
......@@ -127,80 +178,96 @@ def _fork_server(sock):
func_name, args, kwargs = _recv_packet(
sock, who="Prefork server", partner="parent"
)
assert isinstance(func_name, str)
if func_name == "quit":
df.waitall()
_send_packet(sock, ("ok", None))
break
            try:
                result = funcs[func_name](*args, **kwargs)  # type: ignore[operator]
            # FIXME: Is catching all exceptions the right course of action?
            except Exception as e:  # pylint:disable=broad-except
                _send_packet(sock, ("exception", e))
            else:
                _send_packet(sock, ("ok", result))
finally:
sock.close()
import os
os._exit(0)
class IndirectForker(Forker):
def __init__(self, server_pid: int, sock: socket.socket) -> None:
self.server_pid = server_pid
self.socket = sock
import atexit
atexit.register(self._quit)
def _remote_invoke(self, name: str, *args: Any, **kwargs: Any) -> object:
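        # One round trip to the prefork server: send the request over the
        # socket, then re-raise any exception it reports or return its result.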
_send_packet(self.socket, (name, args, kwargs))
status, result = _recv_packet(
self.socket, who="Prefork client", partner="prefork server"
)
if status == "exception":
assert isinstance(result, Exception)
raise result
assert status == "ok"
return result
def _quit(self) -> None:
self._remote_invoke("quit")
from os import waitpid
waitpid(self.server_pid, 0)
def call(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call", cmdline, cwd)
assert isinstance(result, int)
return result
def call_async(self, cmdline: Sequence[str], cwd: str | None = None) -> int:
result = self._remote_invoke("call_async", cmdline, cwd)
assert isinstance(result, int)
return result
def call_capture_output(self,
cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True,
) -> tuple[int, bytes, bytes]:
return self._remote_invoke("call_capture_output", cmdline, cwd, # type: ignore[return-value]
error_on_nonzero)
def wait(self, aid: int) -> int:
result = self._remote_invoke("wait", aid)
assert isinstance(result, int)
return result
def waitall(self) -> dict[int, int]:
result = self._remote_invoke("waitall")
assert isinstance(result, dict)
return result
forker: Forker = DirectForker()
def enable_prefork() -> None:
global forker
if isinstance(forker, IndirectForker):
return
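    # Fork the helper child now, while it is still safe to fork; the child
    # ends up serving requests (see _fork_server), and the parent swaps in an
    # IndirectForker that forwards every call to it over the socketpair.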
s_parent, s_child = socket.socketpair()
from os import fork
fork_res = fork()
......@@ -215,21 +282,23 @@ def enable_prefork():
forker = IndirectForker(fork_res, s_parent)
def call(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call(cmdline, cwd)
def call_async(cmdline: Sequence[str], cwd: str | None = None) -> int:
return forker.call_async(cmdline, cwd)
def call_capture_output(cmdline: Sequence[str],
cwd: str | None = None,
error_on_nonzero: bool = True) -> tuple[int, bytes, bytes]:
return forker.call_capture_output(cmdline, cwd, error_on_nonzero)
def wait(aid: int) -> int:
return forker.wait(aid)
def waitall() -> dict[int, int]:
return forker.waitall()