#! /bin/sh
rsync --verbose --archive --delete _build/html/ doc-upload:doc/pytools
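# The trailing slash on the rsync source copies the directory's contents
# (including dotfiles) rather than relying on the shell to expand a glob.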
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "pytools"
version = "2024.1.21"
description = "A collection of tools for Python"
readme = "README.rst"
license = { text = "MIT" }
authors = [
{ name = "Andreas Kloeckner", email = "inform@tiker.net" },
]
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Other Audience",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Scientific/Engineering :: Visualization",
"Topic :: Software Development :: Libraries",
"Topic :: Utilities",
]
dependencies = [
"platformdirs>=2.2",
# for dataclass_transform with frozen_default
"typing-extensions>=4.5",
]
[project.optional-dependencies]
numpy = [
"numpy>=1.6",
]
test = [
"mypy",
"pytest",
"ruff",
]
siphash = [
"siphash24>=1.6",
]
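# The extras above can be combined at install time, e.g.: pip install "pytools[numpy,siphash]"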
[project.urls]
Documentation = "https://documen.tician.de/pytools/"
Homepage = "https://github.com/inducer/pytools/"
[tool.hatch.build.targets.sdist]
exclude = [
"/.git*",
"/doc/_build",
"/.editorconfig",
"/run-*.sh",
]
[tool.ruff]
preview = true
[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear
"C", # flake8-comprehensions
"E", # pycodestyle
"F", # pyflakes
"G", # flake8-logging-format
"I", # flake8-isort
"N", # pep8-naming
"NPY", # numpy
"Q", # flake8-quotes
"UP", # pyupgrade
"RUF", # ruff
"W", # pycodestyle
"TC",
]
extend-ignore = [
"C90", # McCabe complexity
"E221", # multiple spaces before operator
"E226", # missing whitespace around arithmetic operator
"E402", # module-level import not at top of file
"UP031", # use f-strings instead of %
"UP032", # use f-strings instead of .format
]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
inline-quotes = "double"
multiline-quotes = "double"
[tool.ruff.lint.isort]
combine-as-imports = true
known-local-folder = [
"pytools",
]
lines-after-imports = 2
required-imports = ["from __future__ import annotations"]
[tool.ruff.lint.pep8-naming]
extend-ignore-names = ["update_for_*"]
[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true
warn_unused_ignores = true
# TODO: enable this at some point
# check_untyped_defs = true
[tool.typos.default]
extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$"
]
from __future__ import annotations


def _cp(src, dest):
@@ -22,7 +21,7 @@ def get_timestamp():
    return datetime.now().strftime("%Y-%m-%d-%H%M%S")


class BatchJob:
    def __init__(self, moniker, main_file, aux_files=(), timestamp=None):
        import os
        import os.path
@@ -44,10 +43,9 @@ class BatchJob:
        os.makedirs(self.path)

        runscript = open(f"{self.path}/run.sh", "w")
        import sys
        runscript.write(f"{sys.executable} {main_file} setup.cpy")
        runscript.close()

        from os.path import basename
@@ -65,7 +63,7 @@ class BatchJob:
        setup.close()


class INHERIT:
    pass


@@ -80,20 +78,20 @@ class GridEngineJob(BatchJob):
        from os import getenv

        env = dict(env)
        for var, value in env.items():
            if value is INHERIT:
                value = getenv(var)

            args += ["-v", f"{var}={value}"]

        if memory_megs is not None:
            args.extend(["-l", f"mem={memory_megs}"])

        args.extend(extra_args)

        subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path)
        if subproc.wait() != 0:
            raise RuntimeError(f"Process submission of {self.moniker} failed")


class PBSJob(BatchJob):
@@ -106,32 +104,31 @@ class PBSJob(BatchJob):
        ]

        if memory_megs is not None:
            args.extend(["-l", f"pmem={memory_megs}mb"])

        from os import getenv

        env = dict(env)
        for var, value in env.items():
            if value is INHERIT:
                value = getenv(var)

            args += ["-v", f"{var}={value}"]

        args.extend(extra_args)

        subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path)
        if subproc.wait() != 0:
            raise RuntimeError(f"Process submission of {self.moniker} failed")

def guess_job_class():
    from subprocess import PIPE, STDOUT, Popen

    # text=True so that communicate() returns str, which the split below expects
    qstat_helplines = Popen(["qstat", "--help"], text=True,
            stdout=PIPE, stderr=STDOUT).communicate()[0].split("\n")
    if qstat_helplines[0].startswith("GE"):
        return GridEngineJob
    return PBSJob

class ConstructorPlaceholder:
@@ -147,11 +144,11 @@ class ConstructorPlaceholder:
        return self.kwargs[name]

    def __str__(self):
        return "{}({})".format(self.classname,
                ",".join(
                    [str(arg) for arg in self.args]
                    + [f"{kw}={val!r}"
                        for kw, val in self.kwargs.items()]
                    )
                )
    __repr__ = __str__
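A hedged sketch of how ConstructorPlaceholder renders (its __init__ is elided from this diff; assume it stores classname, args, and kwargs as the methods above suggest):

    p = ConstructorPlaceholder("Solver", 10, tol=1e-6)
    str(p)  # -> "Solver(10,tol=1e-06)"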
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
Tools for Source Code Generation
================================
.. autoclass:: CodeGenerator
.. autoclass:: Indentation
.. autofunction:: remove_common_indentation
"""
from typing import Any
# {{{ code generation
# loosely based on
# http://effbot.org/zone/python-code-generator.htm
class CodeGenerator:
"""Language-agnostic functionality for source code generation.
.. automethod:: extend
.. automethod:: get
.. automethod:: add_to_preamble
.. automethod:: __call__
.. automethod:: indent
.. automethod:: dedent
"""
def __init__(self) -> None:
self.preamble: list[str] = []
self.code: list[str] = []
self.level = 0
self.indent_amount = 4
def extend(self, sub_generator: CodeGenerator) -> None:
for line in sub_generator.code:
self.code.append(" "*(self.indent_amount*self.level) + line)
def get(self) -> str:
result = "\n".join(self.code)
if self.preamble:
result = "\n".join(self.preamble) + "\n" + result
return result
def add_to_preamble(self, s: str) -> None:
self.preamble.append(s)
def __call__(self, s: str) -> None:
if not s.strip():
self.code.append("")
else:
if "\n" in s:
s = remove_common_indentation(s)
for line in s.split("\n"):
self.code.append(" "*(self.indent_amount*self.level) + line)
def indent(self) -> None:
self.level += 1
def dedent(self) -> None:
if self.level == 0:
raise RuntimeError("cannot decrease indentation level")
self.level -= 1
class Indentation:
"""A context manager for indentation for use with :class:`CodeGenerator`.
.. attribute:: generator
.. automethod:: __enter__
.. automethod:: __exit__
"""
def __init__(self, generator: CodeGenerator):
self.generator = generator
def __enter__(self) -> None:
self.generator.indent()
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.generator.dedent()
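A minimal usage sketch combining CodeGenerator and Indentation, mirroring the API defined above:

    cg = CodeGenerator()
    cg("for i in range(10):")
    with Indentation(cg):
        cg("print(i)")
    print(cg.get())

This emits the loop with its body indented by indent_amount (four spaces).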
# }}}
# {{{ remove common indentation
def remove_common_indentation(code: str, require_leading_newline: bool = True) -> str:
r"""Remove leading indentation from one or more lines of code.
Removes an amount of indentation equal to the indentation level of the first
nonempty line in *code*.
:param code: Input string.
:param require_leading_newline: If *True*, only remove indentation if *code*
starts with ``\n``.
:returns: A copy of *code* stripped of leading common indentation.
"""
if "\n" not in code:
return code
if require_leading_newline and not code.startswith("\n"):
return code
lines = code.split("\n")
while lines[0].strip() == "":
lines.pop(0)
while lines[-1].strip() == "":
lines.pop(-1)
if lines:
base_indent = 0
while lines[0][base_indent] in " \t":
base_indent += 1
for line in lines[1:]:
if line[:base_indent].strip():
raise ValueError("inconsistent indentation")
return "\n".join(line[base_indent:] for line in lines)
# }}}
# vim: foldmethod=marker
"""
.. autofunction:: estimate_order_of_convergence
.. autoclass:: EOCRecorder
.. autofunction:: stringify_eocs
.. autoclass:: PConvergenceVerifier
"""

from __future__ import annotations

import numbers

import numpy as np


# {{{ eoc estimation --------------------------------------------------------------

def estimate_order_of_convergence(abscissae, errors):
    r"""Assuming that abscissae and errors are connected by a law of the form

    .. math::

        \text{Error} = \text{constant} \cdot \text{abscissa }^{\text{order}},

    this function finds, in a least-squares sense, the best approximation of
    constant and order for the given data set. It returns a tuple (constant, order).
@@ -22,35 +32,68 @@ def estimate_order_of_convergence(abscissae, errors):
    return 10**coefficients[-1], coefficients[-2]
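For instance, errors that follow err = 5 * h**2 exactly recover both fit parameters:

    estimate_order_of_convergence(
        [0.1, 0.05, 0.025], [5.0e-2, 1.25e-2, 3.125e-3])
    # -> (constant, order) of approximately (5.0, 2.0)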
class EOCRecorder:
    """
    .. automethod:: add_data_point

    .. automethod:: estimate_order_of_convergence
    .. automethod:: order_estimate
    .. automethod:: max_error

    .. automethod:: pretty_print
    .. automethod:: write_gnuplot_file
    """

    def __init__(self) -> None:
        self.history: list[tuple[float, float]] = []

    def add_data_point(self, abscissa: float, error: float) -> None:
        if not (isinstance(abscissa, numbers.Number)
                or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())):
            raise TypeError(
                    f"'abscissa' is not a scalar: '{type(abscissa).__name__}'")

        if not (isinstance(error, numbers.Number)
                or (isinstance(error, np.ndarray) and error.shape == ())):
            raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'")

        self.history.append((abscissa, error))

    def estimate_order_of_convergence(self,
            gliding_mean: int | None = None,
            ) -> np.ndarray:
        abscissae = np.array([a for a, e in self.history])
        errors = np.array([e for a, e in self.history])

        # NOTE: in case any of the errors are exactly 0.0, which
        # can give NaNs in `estimate_order_of_convergence`
        emax: float = np.amax(errors)
        errors += (1 if emax == 0 else emax) * np.finfo(errors.dtype).eps

        size = len(abscissae)
        if gliding_mean is None:
            gliding_mean = size

        data_points = size - gliding_mean + 1
        result: np.ndarray = np.zeros((data_points, 2), float)
        for i in range(data_points):
            result[i, 0], result[i, 1] = estimate_order_of_convergence(
                abscissae[i:i+gliding_mean], errors[i:i+gliding_mean])
        return result

    def order_estimate(self) -> float:
        return self.estimate_order_of_convergence()[0, 1]

    def max_error(self) -> float:
        return max(err for absc, err in self.history)

    def _to_table(self, *,
            abscissa_label="h",
            error_label="Error",
            gliding_mean=2,
            abscissa_format="%s",
            error_format="%s",
            eoc_format="%s"):
        from pytools import Table

        tbl = Table()
@@ -58,37 +101,108 @@ class EOCRecorder:
        gm_eoc = self.estimate_order_of_convergence(gliding_mean)
        for i, (absc, err) in enumerate(self.history):
            absc_str = abscissa_format % absc
            err_str = error_format % err
            if i < gliding_mean-1:
                eoc_str = ""
            else:
                eoc_str = eoc_format % (gm_eoc[i - gliding_mean + 1, 1])

            tbl.add_row((absc_str, err_str, eoc_str))

        if len(self.history) > 1:
            order = self.estimate_order_of_convergence()[0, 1]
            tbl.add_row(("Overall", "", eoc_format % order))

        return tbl

    def pretty_print(self, *,
            abscissa_label: str = "h",
            error_label: str = "Error",
            gliding_mean: int = 2,
            abscissa_format: str = "%s",
            error_format: str = "%s",
            eoc_format: str = "%s",
            table_type: str = "markdown") -> str:
        tbl = self._to_table(
                abscissa_label=abscissa_label, error_label=error_label,
                abscissa_format=abscissa_format,
                error_format=error_format,
                eoc_format=eoc_format,
                gliding_mean=gliding_mean)

        if table_type == "markdown":
            return tbl.github_markdown()
        if table_type == "latex":
            return tbl.latex()
        if table_type == "ascii":
            return str(tbl)
        if table_type == "csv":
            return tbl.csv()

        raise ValueError(f"unknown table type: {table_type}")

    def __str__(self):
        return self.pretty_print()

    def write_gnuplot_file(self, filename: str) -> None:
        outfile = open(filename, "w")
        for absc, err in self.history:
            outfile.write(f"{absc:f} {err:f}\n")

        result = self.estimate_order_of_convergence()
        const = result[0, 0]
        order = result[0, 1]

        outfile.write("\n")
        for absc, _err in self.history:
            outfile.write(f"{absc:f} {const * absc**(-order):f}\n")
def stringify_eocs(*eocs: EOCRecorder,
names: tuple[str, ...] | None = None,
abscissa_label: str = "h",
error_label: str = "Error",
gliding_mean: int = 2,
abscissa_format: str = "%s",
error_format: str = "%s",
eoc_format: str = "%s",
table_type: str = "markdown") -> str:
"""
:arg names: a :class:`tuple` of names to use for the *error_label* of each
*eoc*.
"""
if names is not None and len(names) < len(eocs):
raise ValueError(
f"insufficient names: got {len(names)} names for "
f"{len(eocs)} EOCRecorder instances")
if names is None:
names = tuple(f"{error_label} {i}" for i in range(len(eocs)))
from pytools import merge_tables
tbl = merge_tables(*[eoc._to_table(
abscissa_label=abscissa_label, error_label=name,
abscissa_format=abscissa_format,
error_format=error_format,
eoc_format=eoc_format,
gliding_mean=gliding_mean)
for name, eoc in zip(names, eocs, strict=True)
], skip_columns=(0,))
if table_type == "markdown":
return tbl.github_markdown()
if table_type == "latex":
return tbl.latex()
if table_type == "ascii":
return str(tbl)
if table_type == "csv":
return tbl.csv()
raise ValueError(f"unknown table type: {table_type}")
# }}}


# {{{ p convergence verifier

class PConvergenceVerifier:
    def __init__(self):
        self.orders = []
        self.errors = []
@@ -102,7 +216,7 @@ class PConvergenceVerifier:
        tbl = Table()
        tbl.add_row(("p", "error"))

        for p, err in zip(self.orders, self.errors, strict=True):
            tbl.add_row((str(p), str(err)))

        return str(tbl)
from __future__ import annotations

from typing import IO, TYPE_CHECKING, Any

from pytools import Record

if TYPE_CHECKING:
    from collections.abc import Callable, Iterator, Sequence


__doc__ = """
An in-memory relational database table
======================================

.. autoclass:: DataTable
"""


class Row(Record):
    pass


class DataTable:
    """An in-memory relational database table.

    .. automethod:: __init__
    .. automethod:: copy
    .. automethod:: deep_copy
    .. automethod:: join
    """

    def __init__(self, column_names: Sequence[str],
            column_data: list[Any] | None = None) -> None:
        """Construct a new table, with the given *column_names*.

        :arg column_names: An indexable of column name strings.
        :arg column_data: None or a list of tuples of the same length as
            *column_names* indicating an initial set of data.
        """
        if column_data is None:
            self.data = []
@@ -26,64 +43,64 @@ class DataTable:
            self.data = column_data

        self.column_names = column_names
        self.column_indices = {
                colname: i for i, colname in enumerate(column_names)}

        if len(self.column_indices) != len(self.column_names):
            raise RuntimeError("non-unique column names encountered")

    def __bool__(self) -> bool:
        return bool(self.data)

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self) -> Iterator[list[Any]]:
        return self.data.__iter__()

    def __str__(self) -> str:
        """Return a pretty-printed version of the table."""

        def col_width(i: int) -> int:
            width = len(self.column_names[i])
            if self:
                width = max(width, max(len(str(row[i])) for row in self.data))
            return width
        col_widths = [col_width(i) for i in range(len(self.column_names))]

        def format_row(row: Sequence[str]) -> str:
            return "|".join([str(cell).ljust(col_width)
                for cell, col_width in zip(row, col_widths, strict=True)])

        lines = [format_row(self.column_names),
                "+".join("-"*col_width for col_width in col_widths)] + \
                [format_row(row) for row in self.data]
        return "\n".join(lines)

    def insert(self, **kwargs: Any) -> None:
        values = [None for i in range(len(self.column_names))]

        for key, val in kwargs.items():
            values[self.column_indices[key]] = val

        self.insert_row(tuple(values))

    def insert_row(self, values: tuple[Any, ...]) -> None:
        assert isinstance(values, tuple)
        assert len(values) == len(self.column_names)
        self.data.append(values)

    def insert_rows(self, rows: Sequence[tuple[Any, ...]]) -> None:
        for row in rows:
            self.insert_row(row)

    def filtered(self, **kwargs: Any) -> DataTable:
        if not kwargs:
            return self

        criteria = tuple(
                (self.column_indices[key], value)
                for key, value in kwargs.items())

        result_data = []
@@ -99,43 +116,44 @@ class DataTable:
        return DataTable(self.column_names, result_data)

    def get(self, **kwargs: Any) -> Row:
        filtered = self.filtered(**kwargs)
        if not filtered:
            raise RuntimeError("no matching entry for get()")
        if len(filtered) > 1:
            raise RuntimeError("more than one matching entry for get()")

        return Row(dict(zip(self.column_names, filtered.data[0], strict=True)))

    def clear(self) -> None:
        del self.data[:]

    def copy(self) -> DataTable:
        """Make a copy of the instance, but leave individual rows untouched.

        If the rows are modified later, they will also be modified in the copy.
        """
        return DataTable(self.column_names, self.data[:])

    def deep_copy(self) -> DataTable:
        """Make a copy of the instance down to the row level.

        The copy's rows may be modified independently from the original.
        """
        return DataTable(self.column_names, [row[:] for row in self.data])

    def sort(self, columns: Sequence[str], reverse: bool = False) -> None:
        col_indices = [self.column_indices[col] for col in columns]

        def mykey(row: Sequence[Any]) -> tuple[Any, ...]:
            return tuple(
                    row[col_index]
                    for col_index in col_indices)

        self.data.sort(reverse=reverse, key=mykey)

    def aggregated(self, groupby: Sequence[str], agg_column: str,
            aggregate_func: Callable[[Sequence[Any]], Any]) -> DataTable:
        gb_indices = [self.column_indices[col] for col in groupby]
        agg_index = self.column_indices[agg_column]
@@ -144,14 +162,14 @@ class DataTable:
        result_data = []

        # to pacify pyflakes:
        last_values: tuple[Any, ...] = ()
        agg_values: list[Row] = []

        for row in self.data:
            this_values = tuple(row[i] for i in gb_indices)
            if first or this_values != last_values:
                if not first:
                    result_data.append((*last_values, aggregate_func(agg_values)))

                agg_values = [row[agg_index]]
                last_values = this_values
@@ -160,14 +178,15 @@ class DataTable:
                agg_values.append(row[agg_index])

        if not first and agg_values:
            result_data.append((*this_values, aggregate_func(agg_values)))

        return DataTable(
                [self.column_names[i] for i in gb_indices] + [agg_column],
                result_data)

    def join(self, column: str, other_column: str, other_table: DataTable,
            outer: bool = False) -> DataTable:
        """Return a table joining this and the *other_table* on *column*.

        The new table has the following columns:
        - *column*, titled the same as in this table.
@@ -176,9 +195,9 @@ class DataTable:
        Assumes both tables are sorted ascendingly by the column
        by which they are joined.
        """

        def without(indexable: tuple[str, ...], idx: int) -> tuple[str, ...]:
            return indexable[:idx] + indexable[idx+1:]

        this_key_idx = self.column_indices[column]
@@ -187,9 +206,9 @@ class DataTable:
        this_iter = self.data.__iter__()
        other_iter = other_table.data.__iter__()

        # NOTE: the key column name must be wrapped in a 1-tuple; calling
        # tuple() on the bare string would split it into characters.
        result_columns = (self.column_names[this_key_idx],) + \
                without(tuple(self.column_names), this_key_idx) + \
                without(tuple(other_table.column_names), other_key_idx)

        result_data = []
@@ -225,9 +244,8 @@ class DataTable:
                    except StopIteration:
                        this_over = True
                        break
            elif outer:
                this_batch = [(None,) * len(self.column_names)]

            if run_other and not other_over:
                key = other_key
@@ -238,36 +256,35 @@ class DataTable:
                    except StopIteration:
                        other_over = True
                        break
            elif outer:
                other_batch = [(None,) * len(other_table.column_names)]

            for this_batch_row in this_batch:
                for other_batch_row in other_batch:
                    result_data.append((
                        key,
                        *without(this_batch_row, this_key_idx),
                        *without(other_batch_row, other_key_idx)))

            if outer:
                if this_over and other_over:
                    break
            elif this_over or other_over:
                break

        return DataTable(result_columns, result_data)

    def restricted(self, columns: Sequence[str]) -> DataTable:
        col_indices = [self.column_indices[col] for col in columns]

        return DataTable(columns,
                [[row[i] for i in col_indices] for row in self.data])

    def column_data(self, column: str) -> list[Any]:
        col_index = self.column_indices[column]
        return [row[col_index] for row in self.data]

    def write_csv(self, filelike: IO[Any], **kwargs: Any) -> None:
        from csv import writer
        csvwriter = writer(filelike, **kwargs)
        csvwriter.writerow(self.column_names)
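A small usage sketch for DataTable; note that get() returns a Row, whose fields are accessible as attributes via pytools.Record:

    tbl = DataTable(["name", "age"])
    tbl.insert(name="alice", age=30)
    tbl.insert(name="bob", age=25)
    tbl.sort(["age"])
    print(tbl)                        # pretty-printed table
    print(tbl.get(name="alice").age)  # 30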
from __future__ import annotations

import sys

from pytools import memoize
@@ -13,8 +13,8 @@ def make_unique_filesystem_object(stem, extension="", directory="",
    :param extension: needs a leading dot.
    :param directory: must not have a trailing slash.
    """
    import os
    from os.path import join

    if creator is None:
        def default_creator(name):
@@ -24,7 +24,7 @@ def make_unique_filesystem_object(stem, extension="", directory="",
    i = 0
    while True:
        fname = join(directory, f"{stem}-{i}{extension}")
        try:
            return creator(fname), fname
        except OSError:
@@ -53,11 +53,11 @@ def open_unique_debug_file(stem, extension=""):


# {{{ refcount debugging ------------------------------------------------------

class RefDebugQuit(Exception):  # noqa: N818
    pass


def refdebug(obj, top_level=True, exclude=()):
    from types import FrameType

    def is_excluded(o):
@@ -99,10 +99,10 @@ def refdebug(obj, top_level=True, exclude=()):
            else:
                s = str(r)

            print(f"{idx}/{len(reflist)}: ", id(r), type(r), s)

            if isinstance(r, dict):
                for k, v in r.items():
                    if v is obj:
                        print("...referred to from key", k)
@@ -111,7 +111,7 @@ def refdebug(obj, top_level=True, exclude=()):
            response = input()

            if response == "d":
                refdebug(r, top_level=False, exclude=exclude+tuple(reflist))
                print_head = True
            elif response == "n":
                if idx + 1 < len(reflist):
@@ -131,7 +131,7 @@ def refdebug(obj, top_level=True, exclude=()):
            elif response == "r":
                return
            elif response == "q":
                raise RefDebugQuit
            else:
                print("WHAT YOU SAY!!! (invalid choice)")
@@ -143,10 +143,10 @@ def refdebug(obj, top_level=True, exclude=()):


# {{{ interactive shell

def get_shell_hist_filename() -> str:
    import os

    return os.path.expanduser(os.path.join("~", ".pytools-debug-shell-history"))


def setup_readline():
@@ -156,12 +156,12 @@ def setup_readline():
    try:
        readline.read_history_file(hist_filename)
    except Exception:  # pylint:disable=broad-except
        # http://docs.python.org/3/howto/pyporting.html#capturing-the-currently-raised-exception
        import sys
        e = sys.exc_info()[1]

        from warnings import warn
        warn(f"Error opening readline history file: {e}", stacklevel=2)

    readline.parse_and_bind("tab: complete")
@@ -216,4 +216,27 @@ def shell(locals_=None, globals_=None):

# }}}


# {{{ estimate memory usage

def estimate_memory_usage(root, seen_ids=None):
    if seen_ids is None:
        seen_ids = set()

    id_root = id(root)
    if id_root in seen_ids:
        return 0

    seen_ids.add(id_root)

    result = sys.getsizeof(root)

    from gc import get_referents
    for ref in get_referents(root):
        result += estimate_memory_usage(ref, seen_ids=seen_ids)

    return result
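A quick sketch of estimate_memory_usage: it walks the object graph through gc.get_referents, counting each object id only once:

    nested = {"a": list(range(100)), "b": ("x", "y")}
    print(estimate_memory_usage(nested))  # total size in bytes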
# }}}
# vim: foldmethod=marker
from __future__ import absolute_import
# Python decorator module
# by Michele Simionato
# http://www.phyast.pitt.edu/~micheles/python/
## The basic trick is to generate the source code for the decorated function
## with the right signature and to evaluate it.
## Uncomment the statement 'print >> sys.stderr, func_src' in _decorate
## to understand what is going on.
__all__ = ["decorator", "update_wrapper", "getinfo"]
import inspect
def getinfo(func):
"""
Returns an info dictionary containing:
- name (the name of the function : str)
- argnames (the names of the arguments : list)
- defaults (the values of the default arguments : tuple)
- signature (the signature : str)
- doc (the docstring : str)
- module (the module name : str)
- dict (the function __dict__ : str)
>>> def f(self, x=1, y=2, *args, **kw): pass
>>> info = getinfo(f)
>>> info["name"]
'f'
>>> info["argnames"]
['self', 'x', 'y', 'args', 'kw']
>>> info["defaults"]
(1, 2)
>>> info["signature"]
'self, x, y, *args, **kw'
"""
assert inspect.ismethod(func) or inspect.isfunction(func)
regargs, varargs, varkwargs, defaults = inspect.getargspec(func)
argnames = list(regargs)
if varargs:
argnames.append(varargs)
if varkwargs:
argnames.append(varkwargs)
signature = inspect.formatargspec(regargs, varargs, varkwargs, defaults,
formatvalue=lambda value: "")[1:-1]
return dict(name=func.__name__, argnames=argnames, signature=signature,
defaults = func.__defaults__, doc=func.__doc__,
module=func.__module__, dict=func.__dict__,
globals=func.__globals__, closure=func.__closure__)
def update_wrapper(wrapper, wrapped, create=False):
"""
An improvement over functools.update_wrapper. By default it works the
same, but if the 'create' flag is set, generates a copy of the wrapper
with the right signature and update the copy, not the original.
Moreover, 'wrapped' can be a dictionary with keys 'name', 'doc', 'module',
'dict', 'defaults'.
"""
if isinstance(wrapped, dict):
infodict = wrapped
else: # assume wrapped is a function
infodict = getinfo(wrapped)
assert not '_wrapper_' in infodict["argnames"], \
'"_wrapper_" is a reserved argument name!'
if create: # create a brand new wrapper with the right signature
src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict
# import sys; print >> sys.stderr, src # for debugging purposes
wrapper = eval(src, dict(_wrapper_=wrapper))
try:
wrapper.__name__ = infodict['name']
except: # Python version < 2.4
pass
wrapper.__doc__ = infodict['doc']
wrapper.__module__ = infodict['module']
wrapper.__dict__.update(infodict['dict'])
wrapper.__defaults__ = infodict['defaults']
return wrapper
# the real meat is here
def _decorator(caller, func):
if not (inspect.ismethod(func) or inspect.isfunction(func)):
# skip all the fanciness, just do what works
return lambda *args, **kwargs: caller(func, *args, **kwargs)
infodict = getinfo(func)
argnames = infodict['argnames']
assert not ('_call_' in argnames or '_func_' in argnames), \
'You cannot use _call_ or _func_ as argument names!'
src = "lambda %(signature)s: _call_(_func_, %(signature)s)" % infodict
dec_func = eval(src, dict(_func_=func, _call_=caller))
return update_wrapper(dec_func, func)
def decorator(caller, func=None):
"""
General purpose decorator factory: takes a caller function as
input and returns a decorator with the same attributes.
A caller function is any function like this::
def caller(func, *args, **kw):
# do something
return func(*args, **kw)
Here is an example of usage:
>>> @decorator
... def chatty(f, *args, **kw):
... print "Calling %r" % f.__name__
... return f(*args, **kw)
>>> chatty.__name__
'chatty'
>>> @chatty
... def f(): pass
...
>>> f()
Calling 'f'
For sake of convenience, the decorator factory can also be called with
two arguments. In this case, ``decorator(caller, func)`` is just a
shortcut for ``decorator(caller)(func)``.
"""
from warnings import warn
warn("pytools.decorator is deprecated and will be removed in pytools 12. "
"Use the 'decorator' module directly instead.",
DeprecationWarning, stacklevel=2)
if func is None: # return a decorator function
return update_wrapper(lambda f : _decorator(caller, f), caller)
else: # return a decorated function
return _decorator(caller, func)
if __name__ == "__main__":
import doctest; doctest.testmod()
####################### LEGALESE ##################################
## Redistributions of source code must retain the above copyright
## notice, this list of conditions and the following disclaimer.
## Redistributions in bytecode form must reproduce the above copyright
## notice, this list of conditions and the following disclaimer in
## the documentation and/or other materials provided with the
## distribution.
## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
## DAMAGE.
from __future__ import annotations
__copyright__ = """
Copyright (C) 2013 Andreas Kloeckner
Copyright (C) 2014 Matt Wala
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
Dot helper functions
====================
.. autofunction:: dot_escape
.. autofunction:: show_dot
"""
import html
import logging
import os
logger = logging.getLogger(__name__)
# {{{ graphviz / dot interactive show
def dot_escape(s: str) -> str:
"""
Escape the string *s* for compatibility with the
`dot <http://graphviz.org/>`__ language, particularly
backslashes and HTML tags.
:arg s: The input string to escape.
:returns: *s* with special characters escaped.
"""
# "\" and HTML are significant in graphviz.
return html.escape(s.replace("\\", "\\\\"))
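For example, HTML-significant characters come back as entities:

    dot_escape("<x & y>")  # -> "&lt;x &amp; y&gt;"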
def show_dot(dot_code: str, output_to: str | None = None) -> str | None:
"""
Visualize the graph represented by *dot_code*.
:arg dot_code: An instance of :class:`str` in the `dot <http://graphviz.org/>`__
language to visualize.
:arg output_to: An instance of :class:`str` that can be one of:
- ``"xwindow"`` to visualize the graph as an
`X window <https://en.wikipedia.org/wiki/X_Window_System>`_.
- ``"browser"`` to visualize the graph as an SVG file in the
system's default web-browser.
- ``"svg"`` to store the dot code as an SVG file on the file system.
Returns the path to the generated SVG file.
Defaults to ``"xwindow"`` if X11 support is present, otherwise defaults
to ``"browser"``.
:returns: Depends on *output_to*. If ``"svg"``, returns the path to the
generated SVG file, otherwise returns ``None``.
"""
import subprocess
from tempfile import mkdtemp
temp_dir = mkdtemp(prefix="tmp_pytools_dot")
dot_file_name = "code.dot"
from os.path import join
with open(join(temp_dir, dot_file_name), "w") as dotf:
dotf.write(dot_code)
# {{{ preprocess 'output_to'
if output_to is None:
with subprocess.Popen(["dot", "-T?"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
) as proc:
assert proc.stderr, ("Could not execute the 'dot' program. "
"Please install the 'graphviz' package and "
"make sure it is in your $PATH.")
supported_formats = proc.stderr.read().decode()
if " x11 " in supported_formats and "DISPLAY" in os.environ:
output_to = "xwindow"
else:
output_to = "browser"
# }}}
if output_to == "xwindow":
subprocess.check_call(["dot", "-Tx11", dot_file_name], cwd=temp_dir)
elif output_to in ["browser", "svg"]:
svg_file_name = "code.svg"
subprocess.check_call(["dot", "-Tsvg", "-o", svg_file_name, dot_file_name],
cwd=temp_dir)
full_svg_file_name = join(temp_dir, svg_file_name)
logger.info("show_dot: svg written to '%s'", full_svg_file_name)
if output_to == "svg":
return full_svg_file_name
assert output_to == "browser"
from webbrowser import open as browser_open
browser_open("file://" + full_svg_file_name)
else:
raise ValueError("`output_to` can be one of 'xwindow', 'browser', or 'svg',"
f" got '{output_to}'")
return None
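A usage sketch (requires the graphviz "dot" executable on the $PATH, as noted above):

    svg_path = show_dot("digraph { a -> b; b -> c; }", output_to="svg")
    print(svg_path)  # path to the generated code.svg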
# }}}
# vim: foldmethod=marker
"""See pytools.prefork for this module's reason for being.""" """See pytools.prefork for this module's reason for being."""
from __future__ import absolute_import from __future__ import annotations
import mpi4py.rc # pylint:disable=import-error import mpi4py.rc # pylint:disable=import-error
mpi4py.rc.initialize = False mpi4py.rc.initialize = False
from mpi4py.MPI import * # noqa pylint:disable=wildcard-import,wrong-import-position from mpi4py.MPI import * # noqa pylint:disable=wildcard-import,wrong-import-position
import pytools.prefork # pylint:disable=wrong-import-position import pytools.prefork # pylint:disable=wrong-import-position
pytools.prefork.enable_prefork() pytools.prefork.enable_prefork()
if Is_initialized(): # noqa pylint:disable=undefined-variable # pylint: disable-next=undefined-variable
if Is_initialized(): # type: ignore[name-defined,unused-ignore] # noqa
raise RuntimeError("MPI already initialized before MPI wrapper import") raise RuntimeError("MPI already initialized before MPI wrapper import")