#! /bin/sh
-rsync --verbose --archive --delete _build/html/* doc-upload:doc/pytools
+rsync --verbose --archive --delete _build/html/ doc-upload:doc/pytools
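The trailing slash is the point of this change: rsync then transfers the *contents* of _build/html/ (including dotfiles, which the shell glob `*` skips), and --delete can prune remote files that no longer exist in the build output.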
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "pytools"
version = "2024.1.21"
description = "A collection of tools for Python"
readme = "README.rst"
license = { text = "MIT" }
authors = [
{ name = "Andreas Kloeckner", email = "inform@tiker.net" },
]
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Other Audience",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Scientific/Engineering :: Visualization",
"Topic :: Software Development :: Libraries",
"Topic :: Utilities",
]
dependencies = [
"platformdirs>=2.2",
# for dataclass_transform with frozen_default
"typing-extensions>=4.5",
]
[project.optional-dependencies]
numpy = [
"numpy>=1.6",
]
test = [
"mypy",
"pytest",
"ruff",
]
siphash = [
"siphash24>=1.6",
]
[project.urls]
Documentation = "https://documen.tician.de/pytools/"
Homepage = "https://github.com/inducer/pytools/"
[tool.hatch.build.targets.sdist]
exclude = [
"/.git*",
"/doc/_build",
"/.editorconfig",
"/run-*.sh",
]
[tool.ruff]
preview = true
[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear
"C", # flake8-comprehensions
"E", # pycodestyle
"F", # pyflakes
"G", # flake8-logging-format
"I", # flake8-isort
"N", # pep8-naming
"NPY", # numpy
"Q", # flake8-quotes
"UP", # pyupgrade
"RUF", # ruff
"W", # pycodestyle
"TC",
]
extend-ignore = [
"C90", # McCabe complexity
"E221", # multiple spaces before operator
"E226", # missing whitespace around arithmetic operator
"E402", # module-level import not at top of file
"UP031", # use f-strings instead of %
"UP032", # use f-strings instead of .format
]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
inline-quotes = "double"
multiline-quotes = "double"
[tool.ruff.lint.isort]
combine-as-imports = true
known-local-folder = [
"pytools",
]
lines-after-imports = 2
required-imports = ["from __future__ import annotations"]
[tool.ruff.lint.pep8-naming]
extend-ignore-names = ["update_for_*"]
[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true
warn_unused_ignores = true
# TODO: enable this at some point
# check_untyped_defs = true
[tool.typos.default]
extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:\\s*disable-line$"
]
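Taken together, the configuration above maps onto a handful of everyday commands; a minimal sketch (the package name comes from the [project] table, but exact invocation targets such as the `pytools` path are assumptions and may vary):

pip install "pytools[numpy,siphash]"   # core package plus the optional extras defined above
pip install "pytools[test]"            # development install: mypy, pytest, ruff
ruff check .                           # lint with the rule selection configured above
mypy pytools                           # type-check against python_version = "3.10"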
-from __future__ import absolute_import
-import six
+from __future__ import annotations


def _cp(src, dest):
@@ -22,7 +21,7 @@ def get_timestamp():
    return datetime.now().strftime("%Y-%m-%d-%H%M%S")


-class BatchJob(object):
+class BatchJob:
    def __init__(self, moniker, main_file, aux_files=(), timestamp=None):
        import os
        import os.path
@@ -44,10 +43,9 @@ class BatchJob(object):
        os.makedirs(self.path)

-        runscript = open("%s/run.sh" % self.path, "w")
+        runscript = open(f"{self.path}/run.sh", "w")
        import sys
-        runscript.write("%s %s setup.cpy"
-                % (sys.executable, main_file))
+        runscript.write(f"{sys.executable} {main_file} setup.cpy")
        runscript.close()

        from os.path import basename
@@ -65,7 +63,7 @@ class BatchJob(object):
        setup.close()


-class INHERIT(object):  # noqa
+class INHERIT:
    pass

@@ -80,20 +78,20 @@ class GridEngineJob(BatchJob):
        from os import getenv

        env = dict(env)
-        for var, value in six.iteritems(env):
+        for var, value in env.items():
            if value is INHERIT:
                value = getenv(var)

-            args += ["-v", "%s=%s" % (var, value)]
+            args += ["-v", f"{var}={value}"]

        if memory_megs is not None:
-            args.extend(["-l", "mem=%d" % memory_megs])
+            args.extend(["-l", f"mem={memory_megs}"])

        args.extend(extra_args)

-        subproc = Popen(["qsub"] + args + ["run.sh"], cwd=self.path)
+        subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path)
        if subproc.wait() != 0:
-            raise RuntimeError("Process submission of %s failed" % self.moniker)
+            raise RuntimeError(f"Process submission of {self.moniker} failed")


class PBSJob(BatchJob):
@@ -106,32 +104,31 @@ class PBSJob(BatchJob):
            ]

        if memory_megs is not None:
-            args.extend(["-l", "pmem=%dmb" % memory_megs])
+            args.extend(["-l", f"pmem={memory_megs}mb"])

        from os import getenv

        env = dict(env)
-        for var, value in six.iteritems(env):
+        for var, value in env.items():
            if value is INHERIT:
                value = getenv(var)

-            args += ["-v", "%s=%s" % (var, value)]
+            args += ["-v", f"{var}={value}"]

        args.extend(extra_args)

-        subproc = Popen(["qsub"] + args + ["run.sh"], cwd=self.path)
+        subproc = Popen(["qsub", *args, "run.sh"], cwd=self.path)
        if subproc.wait() != 0:
-            raise RuntimeError("Process submission of %s failed" % self.moniker)
+            raise RuntimeError(f"Process submission of {self.moniker} failed")


def guess_job_class():
-    from subprocess import Popen, PIPE, STDOUT
+    from subprocess import PIPE, STDOUT, Popen
    qstat_helplines = Popen(["qstat", "--help"],
            stdout=PIPE, stderr=STDOUT).communicate()[0].split("\n")
    if qstat_helplines[0].startswith("GE"):
        return GridEngineJob
-    else:
-        return PBSJob
+    return PBSJob


class ConstructorPlaceholder:
@@ -147,11 +144,11 @@ class ConstructorPlaceholder:
        return self.kwargs[name]

    def __str__(self):
-        return "%s(%s)" % (self.classname,
+        return "{}({})".format(self.classname,
                ",".join(
                    [str(arg) for arg in self.args]
-                    + ["%s=%s" % (kw, repr(val))
-                        for kw, val in six.iteritems(self.kwargs)]
+                    + [f"{kw}={val!r}"
+                        for kw, val in self.kwargs.items()]
                    )
                )
    __repr__ = __str__
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
Tools for Source Code Generation
================================
.. autoclass:: CodeGenerator
.. autoclass:: Indentation
.. autofunction:: remove_common_indentation
"""
from typing import Any
# {{{ code generation
# loosely based on
# http://effbot.org/zone/python-code-generator.htm
class CodeGenerator:
"""Language-agnostic functionality for source code generation.
.. automethod:: extend
.. automethod:: get
.. automethod:: add_to_preamble
.. automethod:: __call__
.. automethod:: indent
.. automethod:: dedent
"""
def __init__(self) -> None:
self.preamble: list[str] = []
self.code: list[str] = []
self.level = 0
self.indent_amount = 4
def extend(self, sub_generator: CodeGenerator) -> None:
for line in sub_generator.code:
self.code.append(" "*(self.indent_amount*self.level) + line)
def get(self) -> str:
result = "\n".join(self.code)
if self.preamble:
result = "\n".join(self.preamble) + "\n" + result
return result
def add_to_preamble(self, s: str) -> None:
self.preamble.append(s)
def __call__(self, s: str) -> None:
if not s.strip():
self.code.append("")
else:
if "\n" in s:
s = remove_common_indentation(s)
for line in s.split("\n"):
self.code.append(" "*(self.indent_amount*self.level) + line)
def indent(self) -> None:
self.level += 1
def dedent(self) -> None:
if self.level == 0:
raise RuntimeError("cannot decrease indentation level")
self.level -= 1
class Indentation:
"""A context manager for indentation for use with :class:`CodeGenerator`.
.. attribute:: generator
.. automethod:: __enter__
.. automethod:: __exit__
"""
def __init__(self, generator: CodeGenerator):
self.generator = generator
def __enter__(self) -> None:
self.generator.indent()
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.generator.dedent()
# }}}
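A usage sketch for the two classes above; the generator is language-agnostic, so the emitted text here happens to be C:

cg = CodeGenerator()
cg.add_to_preamble("#include <math.h>")
cg("for (int i = 0; i < n; ++i)")
with Indentation(cg):
    cg("y[i] = sin(x[i]);")

print(cg.get())
# #include <math.h>
# for (int i = 0; i < n; ++i)
#     y[i] = sin(x[i]);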
# {{{ remove common indentation
def remove_common_indentation(code: str, require_leading_newline: bool = True):
r"""Remove leading indentation from one or more lines of code.
Removes an amount of indentation equal to the indentation level of the first
nonempty line in *code*.
:param code: Input string.
:param require_leading_newline: If *True*, only remove indentation if *code*
starts with ``\n``.
:returns: A copy of *code* stripped of leading common indentation.
"""
if "\n" not in code:
return code
if require_leading_newline and not code.startswith("\n"):
return code
lines = code.split("\n")
while lines[0].strip() == "":
lines.pop(0)
while lines[-1].strip() == "":
lines.pop(-1)
if lines:
base_indent = 0
while lines[0][base_indent] in " \t":
base_indent += 1
for line in lines[1:]:
if line[:base_indent].strip():
raise ValueError("inconsistent indentation")
return "\n".join(line[base_indent:] for line in lines)
# }}}
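And a sketch of the helper in action; the leading newline matters because *require_leading_newline* defaults to *True*:

snippet = """
    if (x > 0)
        x = -x;
    """
print(remove_common_indentation(snippet))
# if (x > 0)
#     x = -x;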
# vim: foldmethod=marker
-from __future__ import absolute_import
+"""
+.. autofunction:: estimate_order_of_convergence
+.. autoclass:: EOCRecorder
+.. autofunction:: stringify_eocs
+.. autoclass:: PConvergenceVerifier
+"""
+from __future__ import annotations
+
+import numbers
+
import numpy as np
-from six.moves import range
-from six.moves import zip


# {{{ eoc estimation --------------------------------------------------------------

def estimate_order_of_convergence(abscissae, errors):
-    """Assuming that abscissae and errors are connected by a law of the form
+    r"""Assuming that abscissae and errors are connected by a law of the form

-    error = constant * abscissa ^ (order),
+    .. math::
+
+        \text{Error} = \text{constant} \cdot \text{abscissa}^{\text{order}},

    this function finds, in a least-squares sense, the best approximation of
    constant and order for the given data set. It returns a tuple (constant, order).
@@ -22,35 +32,68 @@ def estimate_order_of_convergence(abscissae, errors):
    return 10**coefficients[-1], coefficients[-2]


-class EOCRecorder(object):
-    def __init__(self):
-        self.history = []
+class EOCRecorder:
+    """
+    .. automethod:: add_data_point
+    .. automethod:: estimate_order_of_convergence
+    .. automethod:: order_estimate
+    .. automethod:: max_error
+    .. automethod:: pretty_print
+    .. automethod:: write_gnuplot_file
+    """
+
+    def __init__(self) -> None:
+        self.history: list[tuple[float, float]] = []
+
+    def add_data_point(self, abscissa: float, error: float) -> None:
+        if not (isinstance(abscissa, numbers.Number)
+                or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())):
+            raise TypeError(
+                    f"'abscissa' is not a scalar: '{type(abscissa).__name__}'")
+
+        if not (isinstance(error, numbers.Number)
+                or (isinstance(error, np.ndarray) and error.shape == ())):
+            raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'")

-    def add_data_point(self, abscissa, error):
        self.history.append((abscissa, error))

-    def estimate_order_of_convergence(self, gliding_mean=None):
+    def estimate_order_of_convergence(self,
+            gliding_mean: int | None = None,
+            ) -> np.ndarray:
        abscissae = np.array([a for a, e in self.history])
        errors = np.array([e for a, e in self.history])

+        # NOTE: in case any of the errors are exactly 0.0, which
+        # can give NaNs in `estimate_order_of_convergence`
+        emax: float = np.amax(errors)
+        errors += (1 if emax == 0 else emax) * np.finfo(errors.dtype).eps

        size = len(abscissae)
        if gliding_mean is None:
            gliding_mean = size

        data_points = size - gliding_mean + 1
-        result = np.zeros((data_points, 2), float)
+        result: np.ndarray = np.zeros((data_points, 2), float)
        for i in range(data_points):
            result[i, 0], result[i, 1] = estimate_order_of_convergence(
                    abscissae[i:i+gliding_mean], errors[i:i+gliding_mean])
        return result

-    def order_estimate(self):
+    def order_estimate(self) -> float:
        return self.estimate_order_of_convergence()[0, 1]

-    def max_error(self):
+    def max_error(self) -> float:
        return max(err for absc, err in self.history)

-    def pretty_print(self, abscissa_label="h", error_label="Error", gliding_mean=2):
+    def _to_table(self, *,
+            abscissa_label="h",
+            error_label="Error",
+            gliding_mean=2,
+            abscissa_format="%s",
+            error_format="%s",
+            eoc_format="%s"):
        from pytools import Table

        tbl = Table()
@@ -58,37 +101,108 @@ class EOCRecorder(object):
        gm_eoc = self.estimate_order_of_convergence(gliding_mean)
        for i, (absc, err) in enumerate(self.history):
+            absc_str = abscissa_format % absc
+            err_str = error_format % err
            if i < gliding_mean-1:
-                tbl.add_row((str(absc), str(err), ""))
+                eoc_str = ""
            else:
-                tbl.add_row((str(absc), str(err), str(gm_eoc[i-gliding_mean+1, 1])))
+                eoc_str = eoc_format % (gm_eoc[i - gliding_mean + 1, 1])
+
+            tbl.add_row((absc_str, err_str, eoc_str))

        if len(self.history) > 1:
-            return str(tbl) + "\n\nOverall EOC: %s" \
-                    % self.estimate_order_of_convergence()[0, 1]
-        else:
+            order = self.estimate_order_of_convergence()[0, 1]
+            tbl.add_row(("Overall", "", eoc_format % order))
+
+        return tbl
+
+    def pretty_print(self, *,
+            abscissa_label: str = "h",
+            error_label: str = "Error",
+            gliding_mean: int = 2,
+            abscissa_format: str = "%s",
+            error_format: str = "%s",
+            eoc_format: str = "%s",
+            table_type: str = "markdown") -> str:
+        tbl = self._to_table(
+                abscissa_label=abscissa_label, error_label=error_label,
+                abscissa_format=abscissa_format,
+                error_format=error_format,
+                eoc_format=eoc_format,
+                gliding_mean=gliding_mean)
+
+        if table_type == "markdown":
+            return tbl.github_markdown()
+        if table_type == "latex":
+            return tbl.latex()
+        if table_type == "ascii":
+            return str(tbl)
+        if table_type == "csv":
+            return tbl.csv()
+
+        raise ValueError(f"unknown table type: {table_type}")

    def __str__(self):
        return self.pretty_print()

-    def write_gnuplot_file(self, filename):
+    def write_gnuplot_file(self, filename: str) -> None:
        outfile = open(filename, "w")
        for absc, err in self.history:
-            outfile.write("%f %f\n" % (absc, err))
+            outfile.write(f"{absc:f} {err:f}\n")

        result = self.estimate_order_of_convergence()
        const = result[0, 0]
        order = result[0, 1]

        outfile.write("\n")
-        for absc, err in self.history:
-            outfile.write("%f %f\n" % (absc, const * absc**(-order)))
+        for absc, _err in self.history:
+            outfile.write(f"{absc:f} {const * absc**(-order):f}\n")
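A minimal usage sketch for the recorder above, with made-up second-order data; note that `pretty_print` now renders a Markdown table by default:

eoc = EOCRecorder()
for n in [16, 32, 64, 128]:
    h = 1.0 / n
    eoc.add_data_point(h, 3.0 * h**2)   # stand-in for a measured error

print(eoc.order_estimate())             # approximately 2.0
print(eoc.pretty_print())               # Markdown table with an EOC column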
+def stringify_eocs(*eocs: EOCRecorder,
+        names: tuple[str, ...] | None = None,
+        abscissa_label: str = "h",
+        error_label: str = "Error",
+        gliding_mean: int = 2,
+        abscissa_format: str = "%s",
+        error_format: str = "%s",
+        eoc_format: str = "%s",
+        table_type: str = "markdown") -> str:
+    """
+    :arg names: a :class:`tuple` of names to use for the *error_label* of each
+        *eoc*.
+    """
+    if names is not None and len(names) < len(eocs):
+        raise ValueError(
+                f"insufficient names: got {len(names)} names for "
+                f"{len(eocs)} EOCRecorder instances")
+
+    if names is None:
+        names = tuple(f"{error_label} {i}" for i in range(len(eocs)))
+
+    from pytools import merge_tables
+    tbl = merge_tables(*[eoc._to_table(
+            abscissa_label=abscissa_label, error_label=name,
+            abscissa_format=abscissa_format,
+            error_format=error_format,
+            eoc_format=eoc_format,
+            gliding_mean=gliding_mean)
+        for name, eoc in zip(names, eocs, strict=True)
+        ], skip_columns=(0,))
+
+    if table_type == "markdown":
+        return tbl.github_markdown()
+    if table_type == "latex":
+        return tbl.latex()
+    if table_type == "ascii":
+        return str(tbl)
+    if table_type == "csv":
+        return tbl.csv()
+
+    raise ValueError(f"unknown table type: {table_type}")
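And a sketch of the new helper, assuming `eoc_u` and `eoc_v` are `EOCRecorder` instances filled with the same abscissae:

print(stringify_eocs(eoc_u, eoc_v, names=("u", "v"), table_type="ascii"))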
# }}}
# {{{ p convergence verifier
-class PConvergenceVerifier(object):
+class PConvergenceVerifier:
def __init__(self):
self.orders = []
self.errors = []
@@ -102,7 +216,7 @@ class PConvergenceVerifier(object):
tbl = Table()
tbl.add_row(("p", "error"))
-        for p, err in zip(self.orders, self.errors):
+        for p, err in zip(self.orders, self.errors, strict=True):
tbl.add_row((str(p), str(err)))
return str(tbl)
-from __future__ import absolute_import
+from __future__ import annotations

-import six
-from six.moves import range, zip
+from typing import IO, TYPE_CHECKING, Any

from pytools import Record

+if TYPE_CHECKING:
+    from collections.abc import Callable, Iterator, Sequence
+
+__doc__ = """
+An in-memory relational database table
+======================================
+
+.. autoclass:: DataTable
+"""


class Row(Record):
    pass


class DataTable:
-    """An in-memory relational database table."""
+    """An in-memory relational database table.
+
+    .. automethod:: __init__
+    .. automethod:: copy
+    .. automethod:: deep_copy
+    .. automethod:: join
+    """

-    def __init__(self, column_names, column_data=None):
+    def __init__(self, column_names: Sequence[str],
+            column_data: list[Any] | None = None) -> None:
        """Construct a new table, with the given C{column_names}.

-        @arg column_names: An indexable of column name strings.
-        @arg column_data: None or a list of tuples of the same length as
-          C{column_names} indicating an initial set of data.
+        :arg column_names: An indexable of column name strings.
+        :arg column_data: None or a list of tuples of the same length as
+            *column_names* indicating an initial set of data.
        """
        if column_data is None:
            self.data = []
@@ -26,64 +43,64 @@ class DataTable:
            self.data = column_data

        self.column_names = column_names
-        self.column_indices = dict(
-                (colname, i) for i, colname in enumerate(column_names))
+        self.column_indices = {
+                colname: i for i, colname in enumerate(column_names)}

        if len(self.column_indices) != len(self.column_names):
            raise RuntimeError("non-unique column names encountered")

-    def __bool__(self):
+    def __bool__(self) -> bool:
        return bool(self.data)

-    def __len__(self):
+    def __len__(self) -> int:
        return len(self.data)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[list[Any]]:
        return self.data.__iter__()

-    def __str__(self):
+    def __str__(self) -> str:
        """Return a pretty-printed version of the table."""

-        def col_width(i):
+        def col_width(i: int) -> int:
            width = len(self.column_names[i])
            if self:
                width = max(width, max(len(str(row[i])) for row in self.data))
            return width
        col_widths = [col_width(i) for i in range(len(self.column_names))]

-        def format_row(row):
+        def format_row(row: Sequence[str]) -> str:
            return "|".join([str(cell).ljust(col_width)
-                for cell, col_width in zip(row, col_widths)])
+                for cell, col_width in zip(row, col_widths, strict=True)])

        lines = [format_row(self.column_names),
                "+".join("-"*col_width for col_width in col_widths)] + \
                        [format_row(row) for row in self.data]
        return "\n".join(lines)

-    def insert(self, **kwargs):
+    def insert(self, **kwargs: Any) -> None:
        values = [None for i in range(len(self.column_names))]

-        for key, val in six.iteritems(kwargs):
+        for key, val in kwargs.items():
            values[self.column_indices[key]] = val

        self.insert_row(tuple(values))

-    def insert_row(self, values):
+    def insert_row(self, values: tuple[Any, ...]) -> None:
        assert isinstance(values, tuple)
        assert len(values) == len(self.column_names)
        self.data.append(values)

-    def insert_rows(self, rows):
+    def insert_rows(self, rows: Sequence[tuple[Any, ...]]) -> None:
        for row in rows:
            self.insert_row(row)

-    def filtered(self, **kwargs):
+    def filtered(self, **kwargs: Any) -> DataTable:
        if not kwargs:
            return self

        criteria = tuple(
                (self.column_indices[key], value)
-                for key, value in six.iteritems(kwargs))
+                for key, value in kwargs.items())

        result_data = []
@@ -99,43 +116,44 @@ class DataTable:
        return DataTable(self.column_names, result_data)

-    def get(self, **kwargs):
+    def get(self, **kwargs: Any) -> Row:
        filtered = self.filtered(**kwargs)
        if not filtered:
            raise RuntimeError("no matching entry for get()")
        if len(filtered) > 1:
            raise RuntimeError("more than one matching entry for get()")

-        return Row(dict(list(zip(self.column_names, filtered.data[0]))))
+        return Row(dict(zip(self.column_names, filtered.data[0], strict=True)))

-    def clear(self):
+    def clear(self) -> None:
        del self.data[:]

-    def copy(self):
+    def copy(self) -> DataTable:
        """Make a copy of the instance, but leave individual rows untouched.

        If the rows are modified later, they will also be modified in the copy.
        """
        return DataTable(self.column_names, self.data[:])

-    def deep_copy(self):
+    def deep_copy(self) -> DataTable:
        """Make a copy of the instance down to the row level.

        The copy's rows may be modified independently from the original.
        """
        return DataTable(self.column_names, [row[:] for row in self.data])

-    def sort(self, columns, reverse=False):
+    def sort(self, columns: Sequence[str], reverse: bool = False) -> None:
        col_indices = [self.column_indices[col] for col in columns]

-        def mykey(row):
+        def mykey(row: Sequence[Any]) -> tuple[Any, ...]:
            return tuple(
                    row[col_index]
                    for col_index in col_indices)

        self.data.sort(reverse=reverse, key=mykey)

-    def aggregated(self, groupby, agg_column, aggregate_func):
+    def aggregated(self, groupby: Sequence[str], agg_column: str,
+            aggregate_func: Callable[[Sequence[Any]], Any]) -> DataTable:
        gb_indices = [self.column_indices[col] for col in groupby]
        agg_index = self.column_indices[agg_column]
@@ -144,14 +162,14 @@ class DataTable:
        result_data = []

        # to pacify pyflakes:
-        last_values = None
-        agg_values = None
+        last_values: tuple[Any, ...] = ()
+        agg_values: list[Row] = []

        for row in self.data:
            this_values = tuple(row[i] for i in gb_indices)
            if first or this_values != last_values:
                if not first:
-                    result_data.append(last_values + (aggregate_func(agg_values),))
+                    result_data.append((*last_values, aggregate_func(agg_values)))

                agg_values = [row[agg_index]]
                last_values = this_values
@@ -160,14 +178,15 @@ class DataTable:
                agg_values.append(row[agg_index])

        if not first and agg_values:
-            result_data.append(this_values + (aggregate_func(agg_values),))
+            result_data.append((*this_values, aggregate_func(agg_values)))

        return DataTable(
                [self.column_names[i] for i in gb_indices] + [agg_column],
                result_data)

-    def join(self, column, other_column, other_table, outer=False):
-        """Return a tabled joining this and the C{other_table} on C{column}.
+    def join(self, column: str, other_column: str, other_table: DataTable,
+            outer: bool = False) -> DataTable:
+        """Return a table joining this and the C{other_table} on C{column}.

        The new table has the following columns:
        - C{column}, titled the same as in this table.
@@ -176,9 +195,9 @@ class DataTable:

        Assumes both tables are sorted ascendingly by the column
        by which they are joined.
-        """  # pylint:disable=too-many-locals,too-many-branches
+        """

-        def without(indexable, idx):
+        def without(indexable: tuple[str, ...], idx: int) -> tuple[str, ...]:
            return indexable[:idx] + indexable[idx+1:]

        this_key_idx = self.column_indices[column]
@@ -187,9 +206,9 @@ class DataTable:
        this_iter = self.data.__iter__()
        other_iter = other_table.data.__iter__()

-        result_columns = [self.column_names[this_key_idx]] + \
-                without(self.column_names, this_key_idx) + \
-                without(other_table.column_names, other_key_idx)
+        result_columns = tuple(self.column_names[this_key_idx]) + \
+                without(tuple(self.column_names), this_key_idx) + \
+                without(tuple(other_table.column_names), other_key_idx)

        result_data = []
@@ -225,9 +244,8 @@ class DataTable:
                    except StopIteration:
                        this_over = True
                        break
-            else:
-                if outer:
-                    this_batch = [(None,) * len(self.column_names)]
+            elif outer:
+                this_batch = [(None,) * len(self.column_names)]

            if run_other and not other_over:
                key = other_key
@@ -238,36 +256,35 @@ class DataTable:
                    except StopIteration:
                        other_over = True
                        break
-            else:
-                if outer:
-                    other_batch = [(None,) * len(other_table.column_names)]
+            elif outer:
+                other_batch = [(None,) * len(other_table.column_names)]

            for this_batch_row in this_batch:
                for other_batch_row in other_batch:
-                    result_data.append((key,)
-                            + without(this_batch_row, this_key_idx)
-                            + without(other_batch_row, other_key_idx))
+                    result_data.append((
+                        key,
+                        *without(this_batch_row, this_key_idx),
+                        *without(other_batch_row, other_key_idx)))

            if outer:
                if this_over and other_over:
                    break
-            else:
-                if this_over or other_over:
-                    break
+            elif this_over or other_over:
+                break

        return DataTable(result_columns, result_data)

-    def restricted(self, columns):
+    def restricted(self, columns: Sequence[str]) -> DataTable:
        col_indices = [self.column_indices[col] for col in columns]

        return DataTable(columns,
                [[row[i] for i in col_indices] for row in self.data])

-    def column_data(self, column):
+    def column_data(self, column: str) -> list[tuple[Any, ...]]:
        col_index = self.column_indices[column]
        return [row[col_index] for row in self.data]

-    def write_csv(self, filelike, **kwargs):
+    def write_csv(self, filelike: IO[Any], **kwargs: Any) -> None:
        from csv import writer
        csvwriter = writer(filelike, **kwargs)
        csvwriter.writerow(self.column_names)
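A short usage sketch of the table API above (column names and values are invented; the attribute access on the result of `get` assumes, as upstream, that `pytools.Record` exposes the dict entries as attributes):

tbl = DataTable(["name", "runtime"])
tbl.insert(name="jacobi", runtime=1.25)
tbl.insert(name="gauss-seidel", runtime=0.75)

tbl.sort(["runtime"])
print(tbl)                              # pretty-printed ASCII table
print(tbl.get(name="jacobi").runtime)   # Row is a Record, so columns are attributes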
-from __future__ import absolute_import, print_function
+from __future__ import annotations

import sys

-import six
-from six.moves import input

from pytools import memoize
@@ -13,8 +13,8 @@ def make_unique_filesystem_object(stem, extension="", directory="",
    :param extension: needs a leading dot.
    :param directory: must not have a trailing slash.
    """
-    from os.path import join
    import os
+    from os.path import join

    if creator is None:
        def default_creator(name):
@@ -24,7 +24,7 @@ def make_unique_filesystem_object(stem, extension="", directory="",
    i = 0
    while True:
-        fname = join(directory, "%s-%d%s" % (stem, i, extension))
+        fname = join(directory, f"{stem}-{i}{extension}")
        try:
            return creator(fname), fname
        except OSError:
@@ -53,11 +53,11 @@ def open_unique_debug_file(stem, extension=""):

# {{{ refcount debugging ------------------------------------------------------

-class RefDebugQuit(Exception):
+class RefDebugQuit(Exception):  # noqa: N818
    pass

-def refdebug(obj, top_level=True, exclude=()):  # noqa: E501 pylint:disable=too-many-locals,too-many-branches,too-many-statements
+def refdebug(obj, top_level=True, exclude=()):
    from types import FrameType

    def is_excluded(o):
@@ -99,10 +99,10 @@ def refdebug(obj, top_level=True, exclude=()):
            else:
                s = str(r)

-            print("%d/%d: " % (idx, len(reflist)), id(r), type(r), s)
+            print(f"{idx}/{len(reflist)}: ", id(r), type(r), s)

            if isinstance(r, dict):
-                for k, v in six.iteritems(r):
+                for k, v in r.items():
                    if v is obj:
                        print("...referred to from key", k)
@@ -111,7 +111,7 @@ def refdebug(obj, top_level=True, exclude=()):
            response = input()

            if response == "d":
-                refdebug(r, top_level=False, exclude=exclude+[reflist])
+                refdebug(r, top_level=False, exclude=exclude+tuple(reflist))
                print_head = True
            elif response == "n":
                if idx + 1 < len(reflist):
@@ -131,7 +131,7 @@ def refdebug(obj, top_level=True, exclude=()):
            elif response == "r":
                return
            elif response == "q":
-                raise RefDebugQuit()
+                raise RefDebugQuit
            else:
                print("WHAT YOU SAY!!! (invalid choice)")
@@ -143,10 +143,10 @@ def refdebug(obj, top_level=True, exclude=()):

# {{{ interactive shell

-def get_shell_hist_filename():
+def get_shell_hist_filename() -> str:
    import os

-    _home = os.environ.get('HOME', '/')
-    return os.path.join(_home, ".pytools-debug-shell-history")
+    return os.path.expanduser(os.path.join("~", ".pytools-debug-shell-history"))

def setup_readline():
@@ -156,12 +156,12 @@ def setup_readline():
    try:
        readline.read_history_file(hist_filename)
    except Exception:  # pylint:disable=broad-except
-        # http://docs.python.org/3/howto/pyporting.html#capturing-the-currently-raised-exception  # noqa: E501 pylint:disable=line-too-long
+        # http://docs.python.org/3/howto/pyporting.html#capturing-the-currently-raised-exception
        import sys
        e = sys.exc_info()[1]

        from warnings import warn
-        warn("Error opening readline history file: %s" % e)
+        warn(f"Error opening readline history file: {e}", stacklevel=2)

    readline.parse_and_bind("tab: complete")
@@ -216,4 +216,27 @@ def shell(locals_=None, globals_=None):

# }}}

+# {{{ estimate memory usage
+
+def estimate_memory_usage(root, seen_ids=None):
+    if seen_ids is None:
+        seen_ids = set()
+
+    id_root = id(root)
+    if id_root in seen_ids:
+        return 0
+
+    seen_ids.add(id_root)
+
+    result = sys.getsizeof(root)
+
+    from gc import get_referents
+    for ref in get_referents(root):
+        result += estimate_memory_usage(ref, seen_ids=seen_ids)
+
+    return result
+
+# }}}

# vim: foldmethod=marker
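A quick sketch of the newly added helper; it recurses through `gc.get_referents`, so the result is an estimate (shared objects are counted once, and objects invisible to the garbage collector are missed):

nested = {"a": list(range(1000)), "b": ("x", "y")}
print(estimate_memory_usage(nested))    # rough total size in bytes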
from __future__ import absolute_import
# Python decorator module
# by Michele Simionato
# http://www.phyast.pitt.edu/~micheles/python/
## The basic trick is to generate the source code for the decorated function
## with the right signature and to evaluate it.
## Uncomment the statement 'print >> sys.stderr, func_src' in _decorate
## to understand what is going on.
__all__ = ["decorator", "update_wrapper", "getinfo"]
import inspect
def getinfo(func):
"""
Returns an info dictionary containing:
- name (the name of the function : str)
- argnames (the names of the arguments : list)
- defaults (the values of the default arguments : tuple)
- signature (the signature : str)
- doc (the docstring : str)
- module (the module name : str)
- dict (the function __dict__ : str)
>>> def f(self, x=1, y=2, *args, **kw): pass
>>> info = getinfo(f)
>>> info["name"]
'f'
>>> info["argnames"]
['self', 'x', 'y', 'args', 'kw']
>>> info["defaults"]
(1, 2)
>>> info["signature"]
'self, x, y, *args, **kw'
"""
assert inspect.ismethod(func) or inspect.isfunction(func)
regargs, varargs, varkwargs, defaults = inspect.getargspec(func)
argnames = list(regargs)
if varargs:
argnames.append(varargs)
if varkwargs:
argnames.append(varkwargs)
signature = inspect.formatargspec(regargs, varargs, varkwargs, defaults,
formatvalue=lambda value: "")[1:-1]
return dict(name=func.__name__, argnames=argnames, signature=signature,
defaults = func.__defaults__, doc=func.__doc__,
module=func.__module__, dict=func.__dict__,
globals=func.__globals__, closure=func.__closure__)
def update_wrapper(wrapper, wrapped, create=False):
"""
An improvement over functools.update_wrapper. By default it works the
same, but if the 'create' flag is set, it generates a copy of the wrapper
with the right signature and updates the copy, not the original.
Moreover, 'wrapped' can be a dictionary with keys 'name', 'doc', 'module',
'dict', 'defaults'.
"""
if isinstance(wrapped, dict):
infodict = wrapped
else: # assume wrapped is a function
infodict = getinfo(wrapped)
assert not '_wrapper_' in infodict["argnames"], \
'"_wrapper_" is a reserved argument name!'
if create: # create a brand new wrapper with the right signature
src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict
# import sys; print >> sys.stderr, src # for debugging purposes
wrapper = eval(src, dict(_wrapper_=wrapper))
try:
wrapper.__name__ = infodict['name']
except: # Python version < 2.4
pass
wrapper.__doc__ = infodict['doc']
wrapper.__module__ = infodict['module']
wrapper.__dict__.update(infodict['dict'])
wrapper.__defaults__ = infodict['defaults']
return wrapper
# the real meat is here
def _decorator(caller, func):
if not (inspect.ismethod(func) or inspect.isfunction(func)):
# skip all the fanciness, just do what works
return lambda *args, **kwargs: caller(func, *args, **kwargs)
infodict = getinfo(func)
argnames = infodict['argnames']
assert not ('_call_' in argnames or '_func_' in argnames), \
'You cannot use _call_ or _func_ as argument names!'
src = "lambda %(signature)s: _call_(_func_, %(signature)s)" % infodict
dec_func = eval(src, dict(_func_=func, _call_=caller))
return update_wrapper(dec_func, func)
def decorator(caller, func=None):
"""
General purpose decorator factory: takes a caller function as
input and returns a decorator with the same attributes.
A caller function is any function like this::
def caller(func, *args, **kw):
# do something
return func(*args, **kw)
Here is an example of usage:
>>> @decorator
... def chatty(f, *args, **kw):
... print "Calling %r" % f.__name__
... return f(*args, **kw)
>>> chatty.__name__
'chatty'
>>> @chatty
... def f(): pass
...
>>> f()
Calling 'f'
For the sake of convenience, the decorator factory can also be called with
two arguments. In this case, ``decorator(caller, func)`` is just a
shortcut for ``decorator(caller)(func)``.
"""
from warnings import warn
warn("pytools.decorator is deprecated and will be removed in pytools 12. "
"Use the 'decorator' module directly instead.",
DeprecationWarning, stacklevel=2)
if func is None: # return a decorator function
return update_wrapper(lambda f : _decorator(caller, f), caller)
else: # return a decorated function
return _decorator(caller, func)
if __name__ == "__main__":
import doctest; doctest.testmod()
####################### LEGALESE ##################################
## Redistributions of source code must retain the above copyright
## notice, this list of conditions and the following disclaimer.
## Redistributions in bytecode form must reproduce the above copyright
## notice, this list of conditions and the following disclaimer in
## the documentation and/or other materials provided with the
## distribution.
## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
## DAMAGE.
from __future__ import annotations
__copyright__ = """
Copyright (C) 2013 Andreas Kloeckner
Copyright (C) 2014 Matt Wala
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
__doc__ = """
Dot helper functions
====================
.. autofunction:: dot_escape
.. autofunction:: show_dot
"""
import html
import logging
import os
logger = logging.getLogger(__name__)
# {{{ graphviz / dot interactive show
def dot_escape(s: str) -> str:
"""
Escape the string *s* for compatibility with the
`dot <http://graphviz.org/>`__ language, particularly
backslashes and HTML tags.
:arg s: The input string to escape.
:returns: *s* with special characters escaped.
"""
# "\" and HTML are significant in graphviz.
return html.escape(s.replace("\\", "\\\\"))
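For instance, following directly from the implementation above:

print(dot_escape(r"A\B <sub>"))    # prints: A\\B &lt;sub&gt;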
def show_dot(dot_code: str, output_to: str | None = None) -> str | None:
"""
Visualize the graph represented by *dot_code*.
:arg dot_code: An instance of :class:`str` in the `dot <http://graphviz.org/>`__
language to visualize.
:arg output_to: An instance of :class:`str` that can be one of:
- ``"xwindow"`` to visualize the graph as an
`X window <https://en.wikipedia.org/wiki/X_Window_System>`_.
- ``"browser"`` to visualize the graph as an SVG file in the
system's default web-browser.
- ``"svg"`` to store the dot code as an SVG file on the file system.
Returns the path to the generated SVG file.
Defaults to ``"xwindow"`` if X11 support is present, otherwise defaults
to ``"browser"``.
:returns: Depends on *output_to*. If ``"svg"``, returns the path to the
generated SVG file, otherwise returns ``None``.
"""
import subprocess
from tempfile import mkdtemp
temp_dir = mkdtemp(prefix="tmp_pytools_dot")
dot_file_name = "code.dot"
from os.path import join
with open(join(temp_dir, dot_file_name), "w") as dotf:
dotf.write(dot_code)
# {{{ preprocess 'output_to'
if output_to is None:
with subprocess.Popen(["dot", "-T?"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
) as proc:
assert proc.stderr, ("Could not execute the 'dot' program. "
"Please install the 'graphviz' package and "
"make sure it is in your $PATH.")
supported_formats = proc.stderr.read().decode()
if " x11 " in supported_formats and "DISPLAY" in os.environ:
output_to = "xwindow"
else:
output_to = "browser"
# }}}
if output_to == "xwindow":
subprocess.check_call(["dot", "-Tx11", dot_file_name], cwd=temp_dir)
elif output_to in ["browser", "svg"]:
svg_file_name = "code.svg"
subprocess.check_call(["dot", "-Tsvg", "-o", svg_file_name, dot_file_name],
cwd=temp_dir)
full_svg_file_name = join(temp_dir, svg_file_name)
logger.info("show_dot: svg written to '%s'", full_svg_file_name)
if output_to == "svg":
return full_svg_file_name
assert output_to == "browser"
from webbrowser import open as browser_open
browser_open("file://" + full_svg_file_name)
else:
raise ValueError("`output_to` can be one of 'xwindow', 'browser', or 'svg',"
f" got '{output_to}'")
return None
# }}}
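A usage sketch for the function above:

path = show_dot("digraph { a -> b -> c }", output_to="svg")
print(path)    # location of the rendered SVG in a fresh temporary directory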
# vim: foldmethod=marker
"""See pytools.prefork for this module's reason for being."""
from __future__ import absolute_import
from __future__ import annotations
import mpi4py.rc # pylint:disable=import-error
mpi4py.rc.initialize = False
from mpi4py.MPI import * # noqa pylint:disable=wildcard-import,wrong-import-position
import pytools.prefork # pylint:disable=wrong-import-position
pytools.prefork.enable_prefork()
if Is_initialized(): # noqa pylint:disable=undefined-variable
# pylint: disable-next=undefined-variable
if Is_initialized(): # type: ignore[name-defined,unused-ignore] # noqa
raise RuntimeError("MPI already initialized before MPI wrapper import")