from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
@@ -23,64 +24,21 @@ THE SOFTWARE.
"""
import marshal
from importlib.util import MAGIC_NUMBER as BYTECODE_VERSION
from types import FunctionType, ModuleType
from pytools.codegen import (  # noqa
    CodeGenerator as CodeGeneratorBase,
    Indentation,
    remove_common_indentation,
)

class PythonCodeGenerator(CodeGeneratorBase):
    def get_module(self, name=None):
        if name is None:
            name = "<generated code>"
result_dict = {}
source_text = self.get()
exec(compile(
@@ -89,8 +47,8 @@ class PythonCodeGenerator(object):
result_dict["_MODULE_SOURCE_CODE"] = source_text
return result_dict
    def get_picklable_module(self, name=None):
        return PicklableModule(self.get_module(name=name))
class PythonFunctionGenerator(PythonCodeGenerator):
@@ -98,14 +56,18 @@ class PythonFunctionGenerator(PythonCodeGenerator):
PythonCodeGenerator.__init__(self)
self.name = name
self("def %s(%s):" % (name, ", ".join(args)))
self("def {}({}):".format(name, ", ".join(args)))
self.indent()
@property
def _gen_filename(self):
return f"<generated code for '{self.name}'>"
def get_function(self):
        return self.get_module(name=self._gen_filename)[self.name]
def get_picklable_function(self):
        module = self.get_picklable_module(name=self._gen_filename)
return PicklableFunction(module, self.name)
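    # A hedged usage sketch (not part of the class): build a function at
    # runtime, then retrieve a plain callable from the generated module.
    #
    #   gen = PythonFunctionGenerator("add", ("x", "y"))
    #   gen("return x + y")
    #   add = gen.get_function()
    #   assert add(2, 3) == 5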
@@ -120,7 +82,7 @@ def _get_empty_module_dict():
_empty_module_dict = _get_empty_module_dict()
class PicklableModule:
def __init__(self, mod_globals):
self.mod_globals = mod_globals
@@ -129,7 +91,7 @@ class PicklableModule(object):
functions = {}
modules = {}
        for k, v in self.mod_globals.items():
if isinstance(v, FunctionType):
functions[k] = (
v.__name__,
@@ -140,7 +102,7 @@ class PicklableModule(object):
elif k not in _empty_module_dict:
nondefault_globals[k] = v
        return (1, BYTECODE_VERSION, functions, modules, nondefault_globals)
def __setstate__(self, obj):
if obj[0] == 0:
@@ -151,19 +113,19 @@ class PicklableModule(object):
else:
raise ValueError("unknown version of PicklableModule")
        if magic != BYTECODE_VERSION:
            raise ValueError(
                "cannot unpickle function binary: incorrect magic value "
                f"(got: {magic!r}, expected: {BYTECODE_VERSION!r})")
mod_globals = _empty_module_dict.copy()
mod_globals.update(nondefault_globals)
        from importlib import import_module
        for k, mod_name in modules.items():
mod_globals[k] = import_module(mod_name)
        for k, (name, code_bytes, argdefs) in functions.items():
f = FunctionType(
marshal.loads(code_bytes), mod_globals, name=name,
argdefs=argdefs)
@@ -176,8 +138,8 @@ class PicklableModule(object):
# {{{ picklable function
class PicklableFunction:
    """Convenience class wrapping a function in a :class:`PicklableModule`.
"""
def __init__(self, module, name):
@@ -199,33 +161,4 @@ class PicklableFunction(object):
# }}}
# vim: foldmethod=marker
from __future__ import annotations
import numpy as np
@@ -8,7 +7,7 @@ def do_boxes_intersect(bl, tr):
(bl1, tr1) = bl
(bl2, tr2) = tr
(dimension,) = bl1.shape
    for i in range(dimension):
if max(bl1[i], bl2[i]) > min(tr1[i], tr2[i]):
return False
return True
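# Illustrative check (hedged; each argument is a (bottom_left, top_right)
# pair of numpy coordinate vectors, matching the convention above):
#
#   import numpy as np
#   assert do_boxes_intersect((np.array([0., 0.]), np.array([2., 2.])),
#                             (np.array([1., 1.]), np.array([3., 3.])))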
@@ -26,12 +25,11 @@ def make_buckets(bottom_left, top_right, allbuckets, max_elements_per_box):
max_elements_per_box=max_elements_per_box)
allbuckets.append(bucket)
return bucket

        pos[dimension] = 0
        first = do(dimension + 1, pos)
        pos[dimension] = 1
        second = do(dimension + 1, pos)
        return [first, second]
return do(0, np.zeros((dimensions,), np.float64))
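# (The recursion above visits all 2**dimensions corner positions, so the
# returned nested lists hold one leaf bucket per orthant, while *allbuckets*
# additionally collects them flat for easy iteration.)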
@@ -106,7 +104,7 @@ class SpatialBinaryTreeBucket:
# No subdivisions yet.
if len(self.elements) > self.max_elements_per_box:
# Too many elements. Need to subdivide.
            self.all_buckets = []
self.buckets = make_buckets(
self.bottom_left, self.top_right,
self.all_buckets,
@@ -138,26 +136,25 @@ class SpatialBinaryTreeBucket:
else:
bucket = bucket[1]
            yield from bucket.generate_matches(point)
# Perform linear search.
for el, _ in self.elements:
yield el
def visualize(self, file):
file.write("%f %f\n" % (self.bottom_left[0], self.bottom_left[1]))
file.write("%f %f\n" % (self.top_right[0], self.bottom_left[1]))
file.write("%f %f\n" % (self.top_right[0], self.top_right[1]))
file.write("%f %f\n" % (self.bottom_left[0], self.top_right[1]))
file.write("%f %f\n\n" % (self.bottom_left[0], self.bottom_left[1]))
file.write(f"{self.bottom_left[0]:f} {self.bottom_left[1]:f}\n")
file.write(f"{self.top_right[0]:f} {self.bottom_left[1]:f}\n")
file.write(f"{self.top_right[0]:f} {self.top_right[1]:f}\n")
file.write(f"{self.bottom_left[0]:f} {self.top_right[1]:f}\n")
file.write(f"{self.bottom_left[0]:f} {self.bottom_left[1]:f}\n\n")
if self.buckets:
for i in self.all_buckets:
i.visualize(file)
def plot(self, **kwargs):
        import matplotlib.patches as mpatches
        import matplotlib.pyplot as pt
from matplotlib.path import Path
el = self.bottom_left
@@ -170,7 +167,7 @@ class SpatialBinaryTreeBucket:
(Path.CLOSEPOLY, (el[0], el[1])),
]
        codes, verts = zip(*pathdata, strict=True)
path = Path(verts, codes)
patch = mpatches.PathPatch(path, **kwargs)
pt.gca().add_patch(patch)
......
from __future__ import annotations
import time
from pytools import DependentDictionary, Reference
class StopWatch:
    def __init__(self) -> None:
        self.Elapsed = 0.0
        self.LastStart: float | None = None
    def start(self) -> StopWatch:
assert self.LastStart is None
self.LastStart = time.time()
return self
    def stop(self) -> StopWatch:
assert self.LastStart is not None
self.Elapsed += time.time() - self.LastStart
self.LastStart = None
return self
    def elapsed(self) -> float:
if self.LastStart:
return time.time() - self.LastStart + self.Elapsed
        return self.Elapsed
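    # Hedged usage sketch: time may be accumulated across several
    # start()/stop() pairs, and elapsed() also works while running.
    #
    #   sw = StopWatch().start()
    #   ...  # do work
    #   total = sw.stop().elapsed()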
class Job:
    def __init__(self, name: str) -> None:
self.Name = name
self.StopWatch = StopWatch().start()
if self.is_visible():
print("%s..." % name)
print(f"{name}...")
    def done(self) -> None:
elapsed = self.StopWatch.elapsed()
JOB_TIMES[self.Name] += elapsed
if self.is_visible():
print(" " * (len(self.Name) + 2), elapsed, "seconds")
    def is_visible(self) -> bool:
if PRINT_JOBS.get():
return self.Name not in HIDDEN_JOBS
        return self.Name in VISIBLE_JOBS
class EtaEstimator:
    def __init__(self, total_steps: int) -> None:
self.stopwatch = StopWatch().start()
self.total_steps = total_steps
assert total_steps > 0
    def estimate(self, done: int) -> float | None:
        fraction_done = done / self.total_steps
time_spent = self.stopwatch.elapsed()
        if fraction_done > 1.0e-5:
            return time_spent / fraction_done - time_spent
        return None
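    # The estimate is a linear extrapolation: after finishing a fraction f
    # of the work in time t, the remaining time is t/f - t, which is exactly
    # what the expression above computes. Hedged usage sketch:
    #
    #   eta = EtaEstimator(total_steps=100)
    #   for step in range(1, 101):
    #       ...  # do work
    #       remaining = eta.estimate(step)  # seconds, or None very early on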
def print_job_summary() -> None:
    for key, value in JOB_TIMES.iteritems():
        print(key, " " * (50 - len(key)), value)
HIDDEN_JOBS: list[str] = []
VISIBLE_JOBS: list[str] = []
JOB_TIMES = DependentDictionary(lambda x: 0)
PRINT_JOBS = Reference(True)
"""
Tag Interface
---------------
.. ``normalize_tags`` undocumented for now. (Not ready to commit.)
.. autofunction:: check_tag_uniqueness
.. autoclass:: Taggable
.. autoclass:: Tag
.. autoclass:: UniqueTag
Supporting Functionality
------------------------
.. autoclass:: DottedName
.. autoclass:: NonUniqueTagError
Internal stuff that is only here because the documentation tool wants it
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. class:: TagT
A type variable with lower bound :class:`Tag`.
"""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypeVar
from warnings import warn
from typing_extensions import Self, dataclass_transform
from pytools import memoize, memoize_method
__copyright__ = """
Copyright (C) 2020 Andreas Klöckner
Copyright (C) 2020 Matt Wala
Copyright (C) 2020 Xiaoyu Wei
Copyright (C) 2020 Nicholas Christensen
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
# {{{ dotted name
class DottedName:
"""
.. attribute:: name_parts
A tuple of strings, each of which is a valid
Python identifier. No name part may start with
a double underscore.
The name (at least morally) exists in the
name space defined by the Python module system.
It need not necessarily identify an importable
object.
.. automethod:: from_class
"""
def __init__(self, name_parts: tuple[str, ...]) -> None:
if len(name_parts) == 0:
raise ValueError("empty name parts")
for p in name_parts:
if not p.isidentifier():
raise ValueError(f"{p} is not a Python identifier")
self.name_parts = name_parts
@classmethod
def from_class(cls, argcls: Any) -> DottedName:
name_parts = tuple(
[str(part) for part in argcls.__module__.split(".")]
+ [str(argcls.__name__)])
if not all(not npart.startswith("__") for npart in name_parts):
raise ValueError(f"some name parts of {'.'.join(name_parts)} "
"start with double underscores")
return cls(name_parts)
def __repr__(self) -> str:
return self.__class__.__name__ + repr(self.name_parts)
def __eq__(self, other: object) -> bool:
if isinstance(other, DottedName):
return self.name_parts == other.name_parts
return False
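# Illustrative example (hedged; relies only on the classmethod above):
#
#   DottedName.from_class(DottedName).name_parts
#   # -> ("pytools", "tag", "DottedName")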
# }}}
# {{{ tag
T = TypeVar("T")
@dataclass_transform(eq_default=True, frozen_default=True)
def tag_dataclass(cls: type[T]) -> type[T]:
return dataclass(init=True, frozen=True, eq=True, repr=True)(cls)
@tag_dataclass
class Tag:
"""
Generic metadata, applied to, among other things,
pytato Arrays.
.. attribute:: tag_name
A fully qualified :class:`DottedName` that reflects
the class name of the tag.
Instances of this type must be immutable, hashable,
picklable, and have a reasonably concise :meth:`__repr__`
of the form ``dotted.name(attr1=value1, attr2=value2)``.
Positional arguments are not allowed.
.. automethod:: __repr__
"""
@property
def tag_name(self) -> DottedName:
return DottedName.from_class(type(self))
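# A hedged sketch of declaring a concrete tag with the decorator above
# (FineGrainedTag is a hypothetical name):
#
#   @tag_dataclass
#   class FineGrainedTag(Tag):
#       level: int
#
#   FineGrainedTag(1) == FineGrainedTag(1)  # frozen dataclass: value equality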
# }}}
# {{{ unique tag
@tag_dataclass
class UniqueTag(Tag):
"""
A superclass for tags that are unique on each :class:`Taggable`.
Each instance of :class:`Taggable` may have no more than one
instance of each subclass of :class:`UniqueTag` in its
set of `tags`. Multiple `UniqueTag` instances of
different (immediate) subclasses are allowed.
"""
# }}}
ToTagSetConvertible = Iterable[Tag] | Tag | None
TagT = TypeVar("TagT", bound="Tag")
# {{{ UniqueTag rules checking
@memoize
def _immediate_unique_tag_descendants(cls: type[Tag]) -> frozenset[type[Tag]]:
if UniqueTag in cls.__bases__:
return frozenset([cls])
result: frozenset[type[Tag]] = frozenset()
for base in cls.__bases__:
result = result | _immediate_unique_tag_descendants(base)
return result
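# (The helper above climbs the class hierarchy and collects every ancestor of
# *cls* that subclasses UniqueTag directly; check_tag_uniqueness below then
# enforces at most one tag per such ancestor.)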
class NonUniqueTagError(ValueError):
"""
    Raised when a :class:`Taggable` object is instantiated with more
    than one :class:`UniqueTag` instance of the same subclass in
    its set of tags.
"""
def check_tag_uniqueness(tags: frozenset[Tag]) -> frozenset[Tag]:
"""Ensure that *tags* obeys the rules set forth in :class:`UniqueTag`.
If not, raise :exc:`NonUniqueTagError`. If any *tags* are not
subclasses of :class:`Tag`, a :exc:`TypeError` will be raised.
:returns: *tags*
"""
unique_tag_descendants: set[type[Tag]] = set()
for tag in tags:
if not isinstance(tag, Tag):
raise TypeError(f"'{tag}' is not an instance of pytools.tag.Tag")
tag_unique_tag_descendants = _immediate_unique_tag_descendants(
type(tag))
intersection = unique_tag_descendants & tag_unique_tag_descendants
if intersection:
raise NonUniqueTagError("Multiple tags are direct subclasses of "
"the following UniqueTag(s): "
f"{', '.join(d.__name__ for d in intersection)}")
unique_tag_descendants.update(tag_unique_tag_descendants)
return tags
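# A hedged sketch of the rule being enforced (RankTag is hypothetical):
#
#   @tag_dataclass
#   class RankTag(UniqueTag):
#       rank: int
#
#   check_tag_uniqueness(frozenset({RankTag(0), RankTag(1)}))
#   # -> raises NonUniqueTagError: both tags descend from the same UniqueTag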
# }}}
def normalize_tags(tags: ToTagSetConvertible) -> frozenset[Tag]:
if isinstance(tags, Tag):
tags = frozenset([tags])
elif tags is None:
tags = frozenset()
else:
tags = frozenset(tags)
return tags
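# e.g. normalize_tags(None) == frozenset(), and a single Tag instance is
# wrapped in a singleton frozenset.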
# {{{ taggable
class Taggable:
"""
Parent class for objects with a `tags` attribute.
.. autoattribute:: tags
.. automethod:: _with_new_tags
.. automethod:: tagged
.. automethod:: without_tags
.. automethod:: tags_of_type
.. automethod:: tags_not_of_type
.. versionadded:: 2021.1
"""
if not TYPE_CHECKING:
def __init__(self, tags: frozenset[Tag] = frozenset()):
warn("The Taggable constructor is deprecated. "
"Subclasses must declare their own storage for .tags. "
"The constructor will disappear in 2025.x.",
DeprecationWarning, stacklevel=2)
self.tags = tags
# ReST references in docstrings must be fully qualified, as docstrings may
# be inherited and appear in different contexts.
# type-checking only so that self.tags = ... in subclasses still works
if TYPE_CHECKING:
@property
def tags(self) -> frozenset[Tag]:
...
def _with_new_tags(self, tags: frozenset[Tag]) -> Self:
"""
Returns a copy of *self* with the specified tags. This method
should be overridden by subclasses.
"""
raise NotImplementedError
def tagged(self, tags: ToTagSetConvertible) -> Self:
"""
Return a copy of *self* with the specified
tag or tags added to the set of tags. If the resulting set of
tags violates the rules on :class:`pytools.tag.UniqueTag`,
an error is raised.
:arg tags: An instance of :class:`~pytools.tag.Tag` or
an iterable with instances therein.
"""
return self._with_new_tags(
tags=check_tag_uniqueness(normalize_tags(tags) | self.tags))
def without_tags(self,
tags: ToTagSetConvertible, verify_existence: bool = True
) -> Self:
"""
Return a copy of *self* without the specified tags.
:arg tags: An instance of :class:`~pytools.tag.Tag` or an iterable with
instances therein.
:arg verify_existence: If set to `True`, this method raises
an exception if not all tags specified for removal are
present in the original set of tags. Default `True`.
"""
to_remove = normalize_tags(tags)
new_tags = self.tags - to_remove
if verify_existence and len(new_tags) > len(self.tags) - len(to_remove):
raise ValueError("A tag specified for removal was not present.")
return self._with_new_tags(tags=check_tag_uniqueness(new_tags))
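    # A hedged sketch of a minimal concrete Taggable (Box is hypothetical):
    #
    #   @dataclass(frozen=True)
    #   class Box(Taggable):
    #       tags: frozenset[Tag] = frozenset()
    #       def _with_new_tags(self, tags):
    #           return Box(tags=tags)
    #
    #   Box().tagged(some_tag)  # new Box with the tag added; original unchanged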
@memoize_method
def tags_of_type(self, tag_t: type[TagT]) -> frozenset[TagT]:
"""
Returns *self*'s tags of type *tag_t*.
"""
return frozenset({tag
for tag in self.tags
if isinstance(tag, tag_t)})
@memoize_method
def tags_not_of_type(self, tag_t: type[TagT]) -> frozenset[Tag]:
"""
Returns *self*'s tags that are not of type *tag_t*.
"""
return frozenset({tag
for tag in self.tags
if not isinstance(tag, tag_t)})
def __eq__(self, other: object) -> bool:
if isinstance(other, Taggable):
return self.tags == other.tags
return super().__eq__(other)
def __hash__(self) -> int:
return hash(self.tags)
# }}}
# {{{ deprecation
_depr_name_to_replacement_and_obj = {
"TagsType": (
"frozenset[Tag]",
frozenset[Tag], 2023),
"TagOrIterableType": (
"ToTagSetConvertible",
ToTagSetConvertible, 2023),
"T_co": (
"Self (i.e. the self type from Python 3.11)",
TypeVar("TaggableT", bound="Taggable"), 2023),
}
def __getattr__(name: str) -> Any:
replacement_and_obj = _depr_name_to_replacement_and_obj.get(name)
if replacement_and_obj is not None:
replacement, obj, year = replacement_and_obj
from warnings import warn
warn(f"'pytools.tag.{name}' is deprecated. "
f"Use '{replacement}' instead. "
f"'pytools.tag.{name}' will continue to work until {year}.",
DeprecationWarning, stacklevel=2)
return obj
raise AttributeError(name)
# }}}
# vim: foldmethod=marker
from __future__ import absolute_import
try:
from py.test import mark as mark_test # pylint:disable=unused-import
except ImportError:
class _Mark:
def __getattr__(self, name):
def dec(f):
return f
return dec
mark_test = _Mark()
from __future__ import annotations


# data from Wikipedia "join" article
def get_dept_table():
@@ -74,7 +71,7 @@ def test_aggregate():
def test_aggregate_2():
from pytools.datatable import DataTable
tbl = DataTable(["step", "value"], list(zip(list(range(20)), list(range(20)))))
tbl = DataTable(["step", "value"], list(zip(range(20), range(20), strict=True)))
agg = tbl.aggregated(["step"], "value", max)
assert agg.column_data("step") == list(range(20))
assert agg.column_data("value") == list(range(20))
......
from __future__ import annotations
__copyright__ = "Copyright (C) 2024 University of Illinois Board of Trustees"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import sys
import pytest
from pytools import opt_frozen_dataclass
def test_opt_frozen_dataclass() -> None:
# {{{ basic usage
@opt_frozen_dataclass()
class A:
x: int
a = A(1)
assert a.x == 1
# Needs to be hashable by default, not using object.__hash__
hash(a)
assert hash(a) == hash(A(1))
assert a == A(1)
# Needs to be frozen by default
if __debug__:
with pytest.raises(AttributeError):
a.x = 2 # type: ignore[misc]
else:
a.x = 2 # type: ignore[misc]
assert a.__dataclass_params__.frozen is __debug__ # type: ignore[attr-defined] # pylint: disable=no-member
# }}}
with pytest.raises(TypeError):
# Can't specify frozen parameter
@opt_frozen_dataclass(frozen=False) # type: ignore[call-arg] # pylint: disable=unexpected-keyword-arg
class B:
x: int
# {{{ eq=False
@opt_frozen_dataclass(eq=False)
class C:
x: int
c = C(1)
# Hashing still works, but uses object.__hash__ (i.e., id())
assert hash(c) != hash(C(1))
# Equality is not defined and uses id()
assert c != C(1)
# }}}
def test_dataclass_weakref() -> None:
if sys.version_info < (3, 11):
pytest.skip("weakref support needs Python 3.11+")
@opt_frozen_dataclass(weakref_slot=True, slots=True)
class Weakref:
x: int
a = Weakref(1)
assert a.x == 1
import weakref
ref = weakref.ref(a)
_ = ref().x
with pytest.raises(TypeError):
@opt_frozen_dataclass(weakref_slot=True) # needs slots=True to work
class Weakref2:
x: int
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])
from __future__ import annotations
import sys
import pytest
def test_compute_sccs():
import random
from pytools.graph import compute_sccs
rng = random.Random(0)
def generate_random_graph(nnodes):
graph = {i: set() for i in range(nnodes)}
for i in range(nnodes):
for j in range(nnodes):
# Edge probability 2/n: Generates decently interesting inputs.
if rng.randint(0, nnodes - 1) <= 1:
graph[i].add(j)
return graph
def verify_sccs(graph, sccs):
visited = set()
def visit(node):
if node in visited:
return []
visited.add(node)
result = []
for child in graph[node]:
result = result + visit(child)
return [*result, node]
for scc in sccs:
scc = set(scc)
assert not scc & visited
# Check that starting from each element of the SCC results
# in the same set of reachable nodes.
for scc_root in scc:
visited.difference_update(scc)
result = visit(scc_root)
assert set(result) == scc, (set(result), scc)
for nnodes in range(10, 20):
for _ in range(40):
graph = generate_random_graph(nnodes)
verify_sccs(graph, compute_sccs(graph))
def test_compute_topological_order():
from pytools.graph import CycleError, compute_topological_order
empty = {}
assert compute_topological_order(empty) == []
disconnected = {1: [], 2: [], 3: []}
assert len(compute_topological_order(disconnected)) == 3
line = list(zip(range(10), ([i] for i in range(1, 11)), strict=True))
import random
random.seed(0)
random.shuffle(line)
expected = list(range(11))
assert compute_topological_order(dict(line)) == expected
claw = {1: [2, 3], 0: [1]}
assert compute_topological_order(claw)[:2] == [0, 1]
repeated_edges = {1: [2, 2], 2: [0]}
assert compute_topological_order(repeated_edges) == [1, 2, 0]
self_cycle = {1: [1]}
with pytest.raises(CycleError):
compute_topological_order(self_cycle)
cycle = {0: [2], 1: [2], 2: [3], 3: [4, 1]}
with pytest.raises(CycleError):
compute_topological_order(cycle)
def test_transitive_closure():
from pytools.graph import compute_transitive_closure
# simple test
graph = {
1: {2},
2: {3},
3: {4},
4: set(),
}
expected_closure = {
1: {2, 3, 4},
2: {3, 4},
3: {4},
4: set(),
}
closure = compute_transitive_closure(graph)
assert closure == expected_closure
# test with branches that reconnect
graph = {
1: {2},
2: set(),
3: {1},
4: {1},
5: {6, 7},
6: {7},
7: {1},
8: {3, 4},
}
expected_closure = {
1: {2},
2: set(),
3: {1, 2},
4: {1, 2},
5: {1, 2, 6, 7},
6: {1, 2, 7},
7: {1, 2},
8: {1, 2, 3, 4},
}
closure = compute_transitive_closure(graph)
assert closure == expected_closure
# test with cycles
graph = {
1: {2},
2: {3},
3: {4},
4: {1},
}
expected_closure = {
1: {1, 2, 3, 4},
2: {1, 2, 3, 4},
3: {1, 2, 3, 4},
4: {1, 2, 3, 4},
}
closure = compute_transitive_closure(graph)
assert closure == expected_closure
def test_graph_cycle_finder():
from pytools.graph import contains_cycle
graph = {
"a": {"b", "c"},
"b": {"d", "e"},
"c": {"d", "f"},
"d": set(),
"e": set(),
"f": {"g"},
"g": set(),
}
assert not contains_cycle(graph)
graph = {
"a": {"b", "c"},
"b": {"d", "e"},
"c": {"d", "f"},
"d": set(),
"e": set(),
"f": {"g"},
"g": {"a"},
}
assert contains_cycle(graph)
graph = {
"a": {"a", "c"},
"b": {"d", "e"},
"c": {"d", "f"},
"d": set(),
"e": set(),
"f": {"g"},
"g": set(),
}
assert contains_cycle(graph)
graph = {
"a": {"a"},
}
assert contains_cycle(graph)
def test_induced_subgraph():
from pytools.graph import compute_induced_subgraph
graph = {
"a": {"b", "c"},
"b": {"d", "e"},
"c": {"d", "f"},
"d": set(),
"e": set(),
"f": {"g"},
"g": {"h", "i", "j"},
}
node_subset = {"b", "c", "e", "f", "g"}
expected_subgraph = {
"b": {"e"},
"c": {"f"},
"e": set(),
"f": {"g"},
"g": set(),
}
subgraph = compute_induced_subgraph(graph, node_subset)
assert subgraph == expected_subgraph
def test_prioritized_topological_sort_examples():
from pytools.graph import compute_topological_order
keys = {"a": 4, "b": 3, "c": 2, "e": 1, "d": 4}
dag = {
"a": ["b", "c"],
"b": [],
"c": ["d", "e"],
"d": [],
"e": []}
assert compute_topological_order(dag, key=keys.get) == [
"a", "c", "e", "b", "d"]
keys = {"a": 7, "b": 2, "c": 1, "d": 0}
dag = {
"d": set("c"),
"b": set("a"),
"a": set(),
"c": set("a"),
}
assert compute_topological_order(dag, key=keys.get) == ["d", "c", "b", "a"]
def test_prioritized_topological_sort():
import random
from pytools.graph import compute_topological_order
rng = random.Random(0)
def generate_random_graph(nnodes):
graph = {i: set() for i in range(nnodes)}
for i in range(nnodes):
# to avoid cycles only consider edges node_i->node_j where j > i.
for j in range(i+1, nnodes):
# Edge probability 4/n: Generates decently interesting inputs.
if rng.randint(0, nnodes - 1) <= 2:
graph[i].add(j)
return graph
nnodes = rng.randint(40, 100)
rev_dep_graph = generate_random_graph(nnodes)
dep_graph = {i: set() for i in range(nnodes)}
for i in range(nnodes):
for rev_dep in rev_dep_graph[i]:
dep_graph[rev_dep].add(i)
keys = [rng.random() for _ in range(nnodes)]
topo_order = compute_topological_order(rev_dep_graph, key=keys.__getitem__)
for scheduled_node in topo_order:
nodes_with_no_deps = {node for node, deps in dep_graph.items()
if len(deps) == 0}
# check whether the order is a valid topological order
assert scheduled_node in nodes_with_no_deps
# check whether priorities are upheld
assert keys[scheduled_node] == min(
keys[node] for node in nodes_with_no_deps)
# 'scheduled_node' is scheduled => no longer a dependency
dep_graph.pop(scheduled_node)
for deps in dep_graph.values():
deps.discard(scheduled_node)
assert len(dep_graph) == 0
def test_as_graphviz_dot():
graph = {"A": ["B", "C"],
"B": [],
"C": ["A"]}
from pytools.graph import NodeT, as_graphviz_dot
def edge_labels(n1: NodeT, n2: NodeT) -> str:
if n1 == "A" and n2 == "B":
return "foo"
return ""
def node_labels(node: NodeT) -> str:
if node == "A":
return "foonode"
return str(node)
res = as_graphviz_dot(graph, node_labels=node_labels, edge_labels=edge_labels)
assert res == \
"""digraph mygraph {
mynodeid [label="foonode"];
mynodeid_0 [label="B"];
mynodeid_1 [label="C"];
mynodeid -> mynodeid_0 [label="foo"];
mynodeid -> mynodeid_1 [label=""];
mynodeid_1 -> mynodeid [label=""];
}
"""
def test_reverse_graph():
graph = {
"a": frozenset(("b", "c")),
"b": frozenset(("d", "e")),
"c": frozenset(("d", "f")),
"d": frozenset(),
"e": frozenset(),
"f": frozenset(("g",)),
"g": frozenset(("h", "i", "j")),
"h": frozenset(),
"i": frozenset(),
"j": frozenset(),
}
from pytools.graph import reverse_graph
assert graph == reverse_graph(reverse_graph(graph))
def test_validate_graph():
from pytools.graph import validate_graph
graph1 = {
"d": set("c"),
"b": set("a"),
"a": set(),
"c": set("a"),
}
validate_graph(graph1)
graph2 = {
"d": set("d"),
"b": set("c"),
"a": set("b"),
"c": set("a"),
}
validate_graph(graph2)
graph3 = {
"a": {"b", "c"},
"b": {"d", "e"},
"c": {"d", "f"},
"d": set(),
"e": set(),
"f": {"g"},
"g": {"h", "i", "j"}, # h, i, j missing from keys
}
with pytest.raises(ValueError):
validate_graph(graph3)
validate_graph({})
def test_is_connected():
from pytools.graph import is_connected
graph1 = {
"d": set("c"),
"b": set("a"),
"a": set(),
"c": set("a"),
}
assert is_connected(graph1)
graph2 = {
"d": set("d"),
"b": set("c"),
"a": set("b"),
"c": set("a"),
}
assert not is_connected(graph2)
graph3 = {
"a": {"b", "c"},
"b": {"d", "e"},
"c": {"d", "f"},
"d": set(),
"e": set(),
"f": {"g"},
"g": {},
}
assert is_connected(graph3)
graph4 = {
"a": {"c"},
"b": {"d", "e"},
"c": {"f"},
"d": set(),
"e": set(),
"f": {"g"},
"g": {},
}
assert not is_connected(graph4)
assert is_connected({})
def test_propagation_graph_tools():
from pytools.graph import (
get_reachable_nodes,
undirected_graph_from_edges,
)
vars = {"a", "b", "c", "d", "e", "f", "g"}
constraints = {
("a", "b"),
("b", "c"),
("b", "d"),
("c", "e"),
("d", "f"),
("e", "g"),
("g", "f"),
("f", "g")
}
all_reachable_nodes = {
"a": frozenset({"a", "b"}),
"b": frozenset({"a", "b"}),
"c": frozenset(),
"d": frozenset(),
"e": frozenset({"e", "f", "g"}),
"f": frozenset({"e", "f", "g"}),
"g": frozenset({"e", "f", "g"})
}
exclude_nodes = {"d", "c"}
propagation_graph = undirected_graph_from_edges(constraints)
    assert all(
        all_reachable_nodes[var] == get_reachable_nodes(
            propagation_graph, var, exclude_nodes)
        for var in vars
    )
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])
from __future__ import annotations
def test_variance():
......
from __future__ import annotations
__copyright__ = "Copyright (C) 2022 University of Illinois Board of Trustees"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import pytest
def test_pytest_raises_on_rank():
from pytools.mpi import pytest_raises_on_rank
def fail(my_rank: int, fail_rank: int) -> None:
if my_rank == fail_rank:
raise ValueError("test failure")
with pytest.raises(ValueError):
fail(0, 0)
fail(0, 1)
with pytest_raises_on_rank(0, 0, ValueError):
# Generates an exception, and pytest_raises_on_rank
# expects one.
fail(0, 0)
with pytest_raises_on_rank(0, 1, ValueError):
# Generates no exception, and pytest_raises_on_rank
# does not expect one.
fail(0, 1)
from __future__ import annotations
import shutil
import sys
import tempfile
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Any
import pytest
from pytools.persistent_dict import (
CollisionWarning,
KeyBuilder,
NoSuchEntryCollisionError,
NoSuchEntryError,
PersistentDict,
ReadOnlyEntryError,
WriteOncePersistentDict,
)
from pytools.tag import Tag, tag_dataclass
# {{{ type for testing
class PDictTestingKeyOrValue:
def __init__(self, val: Any, hash_key=None) -> None:
self.val = val
if hash_key is None:
hash_key = val
self.hash_key = hash_key
def __getstate__(self) -> dict[str, Any]:
return {"val": self.val, "hash_key": self.hash_key}
def __eq__(self, other: object) -> bool:
if type(other) is not PDictTestingKeyOrValue:
return False
return self.val == other.val
def update_persistent_hash(self, key_hash: Any, key_builder: KeyBuilder) -> None:
key_builder.rec(key_hash, self.hash_key)
def __repr__(self) -> str:
return "PDictTestingKeyOrValue(val={!r},hash_key={!r})".format(
self.val, self.hash_key)
# }}}
@tag_dataclass
class SomeTag(Tag):
value: str
class MyEnum(Enum):
YES = 1
NO = 2
class MyIntEnum(IntEnum):
YES = 1
NO = 2
@dataclass
class MyStruct:
name: str
value: int
def test_persistent_dict_storage_and_lookup() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: PersistentDict[Any, int] = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
from random import randrange
def rand_str(n=20):
return "".join(
chr(65+randrange(26))
for i in range(n))
keys = [
(randrange(2000)-1000, rand_str(), None,
SomeTag(rand_str()),
frozenset({"abc", 123}))
for i in range(20)]
values = [randrange(2000) for i in range(20)]
d = dict(zip(keys, values, strict=True))
# {{{ check lookup
for k, v in zip(keys, values, strict=True):
pdict[k] = v
for k, v in d.items():
assert d[k] == pdict[k]
assert v == pdict[k]
# }}}
# {{{ check updating
for k, v in zip(keys, values, strict=True):
pdict[k] = v + 1
for k, v in d.items():
assert d[k] + 1 == pdict[k]
assert v + 1 == pdict[k]
# }}}
# {{{ check store_if_not_present
for k, _ in zip(keys, values, strict=True):
pdict.store_if_not_present(k, d[k] + 2)
for k, v in d.items():
assert d[k] + 1 == pdict[k]
assert v + 1 == pdict[k]
pdict.store_if_not_present(2001, 2001)
assert pdict[2001] == 2001
# }}}
# {{{ check dataclasses
for v in [17, 18]:
key = MyStruct("hi", v)
pdict[key] = v
# reuse same key, with stored hash
assert pdict[key] == v
with pytest.raises(NoSuchEntryError):
pdict[MyStruct("hi", 19)]
for v in [17, 18]:
# make new key instances
assert pdict[MyStruct("hi", v)] == v
# }}}
# {{{ check enums
pdict[MyEnum.YES] = 1
with pytest.raises(NoSuchEntryError):
pdict[MyEnum.NO]
assert pdict[MyEnum.YES] == 1
pdict[MyIntEnum.YES] = 12
with pytest.raises(NoSuchEntryError):
pdict[MyIntEnum.NO]
assert pdict[MyIntEnum.YES] == 12
# }}}
# check not found
with pytest.raises(NoSuchEntryError):
pdict.fetch(3000)
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_deletion() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
pdict[0] = 0
del pdict[0]
with pytest.raises(NoSuchEntryError):
pdict.remove(0)
with pytest.raises(NoSuchEntryError):
pdict.fetch(0)
with pytest.raises(NoSuchEntryError):
del pdict[1]
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_synchronization() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict1: PersistentDict[int, int] = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
pdict2: PersistentDict[int, int] = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
# check lookup
pdict1[0] = 1
assert pdict2[0] == 1
# check updating
pdict1[0] = 2
assert pdict2[0] == 2
# check deletion
del pdict1[0]
with pytest.raises(NoSuchEntryError):
pdict2.fetch(0)
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_cache_collisions() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: PersistentDict[PDictTestingKeyOrValue, int] = \
PersistentDict("pytools-test", container_dir=tmpdir, safe_sync=False)
key1 = PDictTestingKeyOrValue(1, hash_key=0)
key2 = PDictTestingKeyOrValue(2, hash_key=0)
pdict[key1] = 1
# check lookup
with pytest.warns(CollisionWarning):
with pytest.raises(NoSuchEntryCollisionError):
pdict.fetch(key2)
# check deletion
with pytest.warns(CollisionWarning):
with pytest.raises(NoSuchEntryCollisionError):
del pdict[key2]
# check presence after deletion
assert pdict[key1] == 1
# check store_if_not_present
pdict.store_if_not_present(key2, 2)
assert pdict[key1] == 1
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_clear() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
pdict[0] = 1
pdict.fetch(0)
pdict.clear()
with pytest.raises(NoSuchEntryError):
pdict.fetch(0)
finally:
shutil.rmtree(tmpdir)
@pytest.mark.parametrize("in_mem_cache_size", (0, 256))
def test_write_once_persistent_dict_storage_and_lookup(in_mem_cache_size) -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: WriteOncePersistentDict[int, int] = WriteOncePersistentDict(
"pytools-test", container_dir=tmpdir,
in_mem_cache_size=in_mem_cache_size, safe_sync=False)
# check lookup
pdict[0] = 1
assert pdict[0] == 1
# do two lookups to test the cache
assert pdict[0] == 1
# check updating
with pytest.raises(ReadOnlyEntryError):
pdict[0] = 2
# check not found
with pytest.raises(NoSuchEntryError):
pdict.fetch(1)
# check store_if_not_present
pdict.store_if_not_present(0, 2)
assert pdict[0] == 1
pdict.store_if_not_present(1, 1)
assert pdict[1] == 1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_lru_policy() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: WriteOncePersistentDict[Any, Any] = WriteOncePersistentDict(
"pytools-test", container_dir=tmpdir, in_mem_cache_size=3,
safe_sync=False)
pdict[1] = PDictTestingKeyOrValue(1)
pdict[2] = PDictTestingKeyOrValue(2)
pdict[3] = PDictTestingKeyOrValue(3)
pdict[4] = PDictTestingKeyOrValue(4)
val1 = pdict.fetch(1)
assert pdict.fetch(1) is val1
pdict.fetch(2)
assert pdict.fetch(1) is val1
pdict.fetch(2)
pdict.fetch(3)
assert pdict.fetch(1) is val1
pdict.fetch(2)
pdict.fetch(3)
pdict.fetch(2)
assert pdict.fetch(1) is val1
pdict.fetch(2)
pdict.fetch(3)
pdict.fetch(4)
assert pdict.fetch(1) is not val1
# test clear_in_mem_cache
val1 = pdict.fetch(1)
pdict.clear_in_mem_cache()
assert pdict.fetch(1) is not val1
val1 = pdict.fetch(1)
assert pdict.fetch(1) is val1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_synchronization() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict1: WriteOncePersistentDict[int, int] = \
WriteOncePersistentDict("pytools-test", container_dir=tmpdir,
safe_sync=False)
pdict2: WriteOncePersistentDict[int, int] = \
WriteOncePersistentDict("pytools-test", container_dir=tmpdir,
safe_sync=False)
# check lookup
pdict1[1] = 0
assert pdict2[1] == 0
# check updating
with pytest.raises(ReadOnlyEntryError):
pdict2[1] = 1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_cache_collisions() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: WriteOncePersistentDict[Any, int] = \
WriteOncePersistentDict("pytools-test", container_dir=tmpdir,
safe_sync=False)
key1 = PDictTestingKeyOrValue(1, hash_key=0)
key2 = PDictTestingKeyOrValue(2, hash_key=0)
pdict[key1] = 1
# check lookup
with pytest.warns(CollisionWarning):
with pytest.raises(NoSuchEntryCollisionError):
pdict.fetch(key2)
# check update
with pytest.raises(ReadOnlyEntryError):
pdict[key2] = 1
# check store_if_not_present
pdict.store_if_not_present(key2, 2)
assert pdict[key1] == 1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_clear() -> None:
try:
tmpdir = tempfile.mkdtemp()
pdict: WriteOncePersistentDict[int, int] = \
WriteOncePersistentDict("pytools-test", container_dir=tmpdir,
safe_sync=False)
pdict[0] = 1
pdict.fetch(0)
pdict.clear()
with pytest.raises(NoSuchEntryError):
pdict.fetch(0)
finally:
shutil.rmtree(tmpdir)
def test_dtype_hashing() -> None:
np = pytest.importorskip("numpy")
keyb = KeyBuilder()
assert keyb(np.float32) == keyb(np.float32)
assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32))
def test_bool_hashing() -> None:
keyb = KeyBuilder()
assert keyb(True) == keyb(True)
assert keyb(False) == keyb(False)
assert keyb(True) != keyb(False)
np = pytest.importorskip("numpy")
bool_types = [np.bool_]
if hasattr(np, "bool"):
bool_types.append(np.bool)
for bool_type in bool_types:
assert keyb(bool_type) != keyb(bool)
assert keyb(bool_type(True)) == keyb(bool_type(True))
assert keyb(bool_type(False)) == keyb(bool_type(False))
assert keyb(bool_type(True)) != keyb(bool_type(False))
assert keyb(bool_type) != keyb(np.dtype(bool_type))
assert keyb(bool_type(True)) != keyb(np.dtype(bool_type(True)))
assert keyb(bool_type(False)) != keyb(np.dtype(bool_type(False)))
def test_scalar_hashing() -> None:
keyb = KeyBuilder()
assert keyb(1) == keyb(1)
assert keyb(2) != keyb(1)
assert keyb(1.1) == keyb(1.1)
assert keyb(1+4j) == keyb(1+4j)
try:
import numpy as np
except ImportError:
return
assert keyb(np.int8(1)) == keyb(np.int8(1))
assert keyb(np.int16(1)) == keyb(np.int16(1))
assert keyb(np.int32(1)) == keyb(np.int32(1))
assert keyb(np.int32(2)) != keyb(np.int32(1))
assert keyb(np.int64(1)) == keyb(np.int64(1))
assert keyb(1) == keyb(np.int64(1))
assert keyb(1) != keyb(np.int32(1))
assert keyb(np.longlong(1)) == keyb(np.longlong(1))
assert keyb(np.float16(1.1)) == keyb(np.float16(1.1))
assert keyb(np.float32(1.1)) == keyb(np.float32(1.1))
assert keyb(np.float64(1.1)) == keyb(np.float64(1.1))
if hasattr(np, "float128"):
assert keyb(np.float128(1.1)) == keyb(np.float128(1.1))
assert keyb(np.longdouble(1.1)) == keyb(np.longdouble(1.1))
assert keyb(np.complex64(1.1+2.2j)) == keyb(np.complex64(1.1+2.2j))
assert keyb(np.complex128(1.1+2.2j)) == keyb(np.complex128(1.1+2.2j))
if hasattr(np, "complex256"):
assert keyb(np.complex256(1.1+2.2j)) == keyb(np.complex256(1.1+2.2j))
assert keyb(np.clongdouble(1.1+2.2j)) == keyb(np.clongdouble(1.1+2.2j))
@pytest.mark.parametrize("dict_impl", ("immutabledict", "frozendict",
"constantdict",
("immutables", "Map"),
("pyrsistent", "pmap")))
def test_dict_hashing(dict_impl) -> None:
if isinstance(dict_impl, str):
dict_package = dict_impl
dict_class = dict_impl
else:
dict_package = dict_impl[0]
dict_class = dict_impl[1]
pytest.importorskip(dict_package)
import importlib
dc = getattr(importlib.import_module(dict_package), dict_class)
keyb = KeyBuilder()
d = {"a": 1, "b": 2}
assert keyb(dc(d)) == keyb(dc(d))
assert keyb(dc(d)) != keyb(dc({"a": 1, "b": 3}))
assert keyb(dc(d)) == keyb(dc({"b": 2, "a": 1}))
def test_frozenset_hashing() -> None:
keyb = KeyBuilder()
assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([1, 2, 3]))
assert keyb(frozenset([1, 2, 3])) != keyb(frozenset([1, 2, 4]))
assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([3, 2, 1]))
def test_frozenorderedset_hashing() -> None:
pytest.importorskip("orderedsets")
from orderedsets import FrozenOrderedSet
keyb = KeyBuilder()
assert (keyb(FrozenOrderedSet([1, 2, 3]))
== keyb(FrozenOrderedSet([1, 2, 3]))
== keyb(frozenset([1, 2, 3])))
assert keyb(FrozenOrderedSet([1, 2, 3])) != keyb(FrozenOrderedSet([1, 2, 4]))
assert keyb(FrozenOrderedSet([1, 2, 3])) == keyb(FrozenOrderedSet([3, 2, 1]))
def test_ABC_hashing() -> None: # noqa: N802
from abc import ABC, ABCMeta
keyb = KeyBuilder()
class MyABC(ABC): # noqa: B024
pass
assert keyb(MyABC) != keyb(ABC)
with pytest.raises(TypeError):
keyb(MyABC())
with pytest.raises(TypeError):
keyb(ABC())
class MyABC2(MyABC):
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, 42)
assert keyb(MyABC2) != keyb(MyABC)
assert keyb(MyABC2())
class MyABC3(metaclass=ABCMeta): # noqa: B024
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, 42)
assert keyb(MyABC3) != keyb(MyABC) != keyb(MyABC3())
def test_class_hashing() -> None:
keyb = KeyBuilder()
class WithUpdateMethod:
def update_persistent_hash(self, key_hash, key_builder):
# Only called for instances of this class, not for the class itself
key_builder.rec(key_hash, 42)
class TagClass(Tag):
pass
@tag_dataclass
class TagClass2(Tag):
pass
assert keyb(WithUpdateMethod) != keyb(WithUpdateMethod())
assert keyb(TagClass) != keyb(TagClass())
assert keyb(TagClass2) != keyb(TagClass2())
assert keyb(TagClass) != keyb(TagClass2)
assert keyb(TagClass()) != keyb(TagClass2())
assert keyb(TagClass()) == "7b3e4e66503438f6"
assert keyb(TagClass2) == "690b86bbf51aad83"
@tag_dataclass
class TagClass3(Tag):
s: str
assert (keyb(TagClass3("foo")) == "cf1a33652cc75b9c")
def test_dataclass_hashing() -> None:
keyb = KeyBuilder()
@dataclass
class MyDC:
name: str
value: int
assert keyb(MyDC("hi", 1)) == "d1a1079f1c10aa4f"
assert keyb(MyDC("hi", 1)) == keyb(MyDC("hi", 1))
assert keyb(MyDC("hi", 1)) != keyb(MyDC("hi", 2))
@dataclass
class MyDC2:
name: str
value: int
# Class types must be encoded in hash
assert keyb(MyDC2("hi", 1)) != keyb(MyDC("hi", 1))
def test_attrs_hashing() -> None:
attrs = pytest.importorskip("attrs")
keyb = KeyBuilder()
@attrs.define
class MyAttrs:
name: str
value: int
assert (keyb(MyAttrs("hi", 1)) == "5b6c5da60eb2bd0f") # type: ignore[call-arg]
assert keyb(MyAttrs("hi", 1)) == keyb(MyAttrs("hi", 1)) # type: ignore[call-arg]
assert keyb(MyAttrs("hi", 1)) != keyb(MyAttrs("hi", 2)) # type: ignore[call-arg]
@dataclass
class MyDC:
name: str
value: int
assert keyb(MyDC("hi", 1)) != keyb(MyAttrs("hi", 1)) # type: ignore[call-arg]
@attrs.define
class MyAttrs2:
name: str
value: int
# Class types must be encoded in hash
assert (keyb(MyAttrs2("hi", 1)) # type: ignore[call-arg]
!= keyb(MyAttrs("hi", 1))) # type: ignore[call-arg]
def test_datetime_hashing() -> None:
keyb = KeyBuilder()
import datetime
# {{{ date
# No timezone info; date is always naive
assert (keyb(datetime.date(2020, 1, 1))
== keyb(datetime.date(2020, 1, 1))
== "1c866ff10ff0d997")
assert keyb(datetime.date(2020, 1, 1)) != keyb(datetime.date(2020, 1, 2))
# }}}
# {{{ time
# Must distinguish between naive and aware time objects
# Naive time
assert (keyb(datetime.time(12, 0))
== keyb(datetime.time(12, 0))
== keyb(datetime.time(12, 0, 0))
== keyb(datetime.time(12, 0, 0, 0))
== "e523be74ebc6b227")
assert keyb(datetime.time(12, 0)) != keyb(datetime.time(12, 1))
# Aware time
t1 = datetime.time(12, 0, tzinfo=datetime.timezone.utc)
t2 = datetime.time(7, 0,
tzinfo=datetime.timezone(datetime.timedelta(hours=-5)))
t3 = datetime.time(7, 0,
tzinfo=datetime.timezone(datetime.timedelta(hours=-4)))
assert t1 == t2
assert (keyb(t1)
== keyb(t2)
== "2041e7cd5b17b8eb")
assert t1 != t3
assert keyb(t1) != keyb(t3)
# }}}
# {{{ datetime
# must distinguish between naive and aware datetime objects
# Aware datetime
dt1 = datetime.datetime(2020, 1, 1, 12, tzinfo=datetime.timezone.utc)
dt2 = datetime.datetime(2020, 1, 1, 7,
tzinfo=datetime.timezone(datetime.timedelta(hours=-5)))
assert dt1 == dt2
assert (keyb(dt1)
== keyb(dt2)
== "8be96b9e739c7d8c")
dt3 = datetime.datetime(2020, 1, 1, 7,
tzinfo=datetime.timezone(datetime.timedelta(hours=-4)))
assert dt1 != dt3
assert keyb(dt1) != keyb(dt3)
# Naive datetime
dt4 = datetime.datetime(2020, 1, 1, 6) # matches dt1 'naively'
assert dt1 != dt4 # naive and aware datetime objects are never equal
assert keyb(dt1) != keyb(dt4)
assert (keyb(datetime.datetime(2020, 1, 1))
== keyb(datetime.datetime(2020, 1, 1))
== keyb(datetime.datetime(2020, 1, 1, 0, 0, 0, 0))
== "215dbe82add7a55c" # spellchecker: disable-line
)
assert keyb(datetime.datetime(2020, 1, 1)) != keyb(datetime.datetime(2020, 1, 2))
assert (keyb(datetime.datetime(2020, 1, 1))
!= keyb(datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc)))
# }}}
# {{{ timezone
tz1 = datetime.timezone(datetime.timedelta(hours=-4))
tz2 = datetime.timezone(datetime.timedelta(hours=0))
tz3 = datetime.timezone.utc
assert tz1 != tz2
assert keyb(tz1) != keyb(tz2)
assert tz1 != tz3
assert keyb(tz1) != keyb(tz3)
assert tz2 == tz3
assert (keyb(tz2)
== keyb(tz3)
== "5e1d46ab778c7ccf")
# }}}
def test_xdg_cache_home() -> None:
import os
xdg_dir = "tmpdir_pytools_xdg_test"
assert not os.path.exists(xdg_dir)
old_xdg_cache_home = os.environ.get("XDG_CACHE_HOME")
try:
os.environ["XDG_CACHE_HOME"] = xdg_dir
PersistentDict("pytools-test", safe_sync=False)
assert os.path.exists(xdg_dir)
finally:
if old_xdg_cache_home is not None:
os.environ["XDG_CACHE_HOME"] = old_xdg_cache_home
else:
del os.environ["XDG_CACHE_HOME"]
shutil.rmtree(xdg_dir)
def test_speed():
import time
tmpdir = tempfile.mkdtemp()
pdict = WriteOncePersistentDict("pytools-test", container_dir=tmpdir,
safe_sync=False)
start = time.time()
for i in range(10000):
pdict[i] = i
end = time.time()
print("persistent dict write time: ", end-start)
start = time.time()
for _ in range(5):
for i in range(10000):
pdict[i]
end = time.time()
print("persistent dict read time: ", end-start)
shutil.rmtree(tmpdir)
def test_size():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir, safe_sync=False)
for i in range(10000):
pdict[f"foobarbazfoobbb{i}"] = i
size = pdict.nbytes()
print("sqlite size: ", size/1024/1024, " MByte")
assert 1024*1024//2 < size < 2*1024*1024
finally:
shutil.rmtree(tmpdir)
def test_len():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir, safe_sync=False)
assert len(pdict) == 0
for i in range(10000):
pdict[i] = i
assert len(pdict) == 10000
pdict.clear()
assert len(pdict) == 0
finally:
shutil.rmtree(tmpdir)
def test_repr():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir, safe_sync=False)
assert repr(pdict)[:15] == "PersistentDict("
finally:
shutil.rmtree(tmpdir)
def test_keys_values_items():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir, safe_sync=False)
for i in range(10000):
pdict[i] = i
# This also tests deterministic iteration order
assert len(pdict) == 10000 == len(set(pdict))
assert list(pdict.keys()) == list(range(10000))
assert list(pdict.values()) == list(range(10000))
assert list(pdict.items()) == list(zip(pdict, range(10000), strict=True))
assert ([k for k in pdict.keys()] # noqa: C416
== list(pdict.keys())
== list(pdict)
== [k for k in pdict]) # noqa: C416
finally:
shutil.rmtree(tmpdir)
def global_fun():
pass
def global_fun2():
pass
def test_hash_function() -> None:
keyb = KeyBuilder()
# {{{ global functions
assert keyb(global_fun) == keyb(global_fun) == "79efd03f9a38ed77"
assert keyb(global_fun) != keyb(global_fun2)
# }}}
# {{{ closures
def get_fun(x):
def add_x(y):
return x + y
return add_x
f1 = get_fun(1)
f11 = get_fun(1)
f2 = get_fun(2)
fa = get_fun
fb = get_fun
assert fa == fb
assert keyb(fa) == keyb(fb)
assert f1 != f2
assert keyb(f1) != keyb(f2)
# FIXME: inconsistency!
assert f1 != f11
assert hash(f1) != hash(f11)
assert keyb(f1) == keyb(f11)
# }}}
# {{{ local functions
def local_fun():
pass
def local_fun2():
pass
assert keyb(local_fun) == keyb(local_fun) == "adc92e690b62dc2b"
assert keyb(local_fun) != keyb(local_fun2)
# }}}
# {{{ methods
class C1:
def method(self):
pass
class C2:
def method(self):
pass
assert keyb(C1.method) == keyb(C1.method) == "af19e056ad7749c4"
assert keyb(C1.method) != keyb(C2.method)
# }}}
# {{{ basic concurrency tests
def _conc_fn(tmpdir: str | None = None,
pdict: PersistentDict[int, int] | None = None) -> None:
import time
assert (pdict is None) ^ (tmpdir is None)
if pdict is None:
pdict = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
n = 10000
s = 0
start = time.time()
for i in range(n):
if i % 1000 == 0:
print(f"i={i}")
if isinstance(pdict, WriteOncePersistentDict):
try:
pdict[i] = i
except ReadOnlyEntryError:
pass
else:
pdict[i] = i
try:
s += pdict[i]
except NoSuchEntryError:
# Someone else already deleted the entry
pass
if not isinstance(pdict, WriteOncePersistentDict):
try:
del pdict[i]
except NoSuchEntryError:
# Someone else already deleted the entry
pass
end = time.time()
print(f"PersistentDict: time taken to write {n} entries to "
f"{pdict.filename}: {end-start} s={s}")
def test_concurrency_processes() -> None:
from multiprocessing import Process
tmpdir = "_tmp_proc/" # must be the same across all processes in this test
try:
# multiprocessing needs to pickle function arguments, so we can't pass
# the PersistentDict object (which is unpicklable) directly.
p = [Process(target=_conc_fn, args=(tmpdir, None)) for _ in range(4)]
for pp in p:
pp.start()
for pp in p:
pp.join()
assert all(pp.exitcode == 0 for pp in p), [pp.exitcode for pp in p]
finally:
shutil.rmtree(tmpdir)
from threading import Thread
class RaisingThread(Thread):
def run(self) -> None:
self._exc = None
try:
super().run()
except Exception as e:
self._exc = e
def join(self, timeout: float | None = None) -> None:
super().join(timeout=timeout)
if self._exc:
raise self._exc
def test_concurrency_threads() -> None:
tmpdir = "_tmp_threads/" # must be the same across all threads in this test
try:
# Share this pdict object among all threads to test thread safety
pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
container_dir=tmpdir,
safe_sync=False)
t = [RaisingThread(target=_conc_fn, args=(None, pdict)) for _ in range(4)]
for tt in t:
tt.start()
for tt in t:
tt.join()
# Threads will raise in join() if they encountered an exception
finally:
shutil.rmtree(tmpdir)
try:
# Share this pdict object among all threads to test thread safety
pdict2: WriteOncePersistentDict[int, int] = WriteOncePersistentDict(
"pytools-test",
container_dir=tmpdir,
safe_sync=False)
t = [RaisingThread(target=_conc_fn, args=(None, pdict2)) for _ in range(4)]
for tt in t:
tt.start()
for tt in t:
tt.join()
# Threads will raise in join() if they encountered an exception
finally:
shutil.rmtree(tmpdir)
# }}}
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
pytest.main([__file__])
from __future__ import annotations
import sys
......
from __future__ import annotations
__copyright__ = "Copyright (C) 2009-2021 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import logging
import sys
from dataclasses import dataclass
import pytest
from pytools import Record
from pytools.tag import tag_dataclass
logger = logging.getLogger(__name__)
def test_memoize_method_clear():
from pytools import memoize_method
class SomeClass:
def __init__(self):
self.run_count = 0
@memoize_method
def f(self):
self.run_count += 1
return 17
sc = SomeClass()
sc.f()
sc.f()
assert sc.run_count == 1
sc.f.clear_cache(sc)
def test_keyed_memoize_method_with_uncached():
from pytools import keyed_memoize_method
class SomeClass:
def __init__(self):
self.run_count = 0
@keyed_memoize_method(key=lambda x, y, z: x)
def f(self, x, y, z):
del x, y, z
self.run_count += 1
return 17
sc = SomeClass()
sc.f(17, 18, z=19)
sc.f(17, 19, z=20)
assert sc.run_count == 1
sc.f(18, 19, z=20)
assert sc.run_count == 2
sc.f.clear_cache(sc)
def test_memoize_in():
from pytools import memoize_in
class SomeClass:
def __init__(self):
self.run_count = 0
def f(self):
@memoize_in(self, (SomeClass.f,))
def inner(x):
self.run_count += 1
return 2*x
inner(5)
inner(5)
sc = SomeClass()
sc.f()
assert sc.run_count == 1
def test_p_convergence_verifier():
pytest.importorskip("numpy")
from pytools.convergence import PConvergenceVerifier
pconv_verifier = PConvergenceVerifier()
for order in [2, 3, 4, 5]:
pconv_verifier.add_data_point(order, 0.1**order)
pconv_verifier()
pconv_verifier = PConvergenceVerifier()
for order in [2, 3, 4, 5]:
pconv_verifier.add_data_point(order, 0.5**order)
pconv_verifier()
pconv_verifier = PConvergenceVerifier()
for order in [2, 3, 4, 5]:
pconv_verifier.add_data_point(order, 2)
with pytest.raises(AssertionError):
pconv_verifier()
def test_memoize():
from pytools import memoize
count = [0]
@memoize
def f(i, j):
count[0] += 1
return i + j
assert f(1, 2) == 3
assert f(1, 2) == 3
assert count[0] == 1
def test_memoize_with_kwargs():
from pytools import memoize
count = [0]
@memoize(use_kwargs=True)
def f(i, j=1):
count[0] += 1
return i + j
assert f(1) == 2
assert f(1, 2) == 3
assert f(2, j=3) == 5
assert count[0] == 3
assert f(1) == 2
assert f(1, 2) == 3
assert f(2, j=3) == 5
assert count[0] == 3
def test_memoize_keyfunc():
from pytools import memoize
count = [0]
@memoize(key=lambda i, j=(1,): (i, len(j)))
def f(i, j=(1,)):
count[0] += 1
return i + len(j)
assert f(1) == 2
assert f(1, [2]) == 2
assert f(2, j=[2, 3]) == 4
assert count[0] == 2
assert f(1) == 2
assert f(1, (2,)) == 2
assert f(2, j=(2, 3)) == 4
assert count[0] == 2
def test_memoize_frozen() -> None:
from pytools import memoize_method
# {{{ check frozen dataclass
@dataclass(frozen=True)
class FrozenDataclass:
value: int
@memoize_method
def double_value(self):
return 2 * self.value
c0 = FrozenDataclass(10)
assert c0.double_value() == 20
c0.double_value.clear_cache(c0) # type: ignore[attr-defined]
# }}}
# {{{ check class with no setattr
class FrozenClass:
value: int
def __init__(self, value):
object.__setattr__(self, "value", value)
def __setattr__(self, key, value):
raise AttributeError(f"cannot set attribute {key}")
@memoize_method
def double_value(self):
return 2 * self.value
c1 = FrozenClass(10)
assert c1.double_value() == 20
c1.double_value.clear_cache(c1) # type: ignore[attr-defined]
# }}}
@pytest.mark.parametrize("dims", [2, 3])
def test_spatial_btree(dims, do_plot=False):
pytest.importorskip("numpy")
import numpy as np
rng = np.random.default_rng()
nparticles = 2000
x = -1 + 2*rng.uniform(size=(dims, nparticles))
x = np.sign(x)*np.abs(x)**1.9
x = (1.4 + x) % 2 - 1
bl = np.min(x, axis=-1)
tr = np.max(x, axis=-1)
print(bl, tr)
from pytools.spatial_btree import SpatialBinaryTreeBucket
tree = SpatialBinaryTreeBucket(bl, tr, max_elements_per_box=10)
for i in range(nparticles):
tree.insert(i, (x[:, i], x[:, i]))
if do_plot:
import matplotlib.pyplot as pt
pt.gca().set_aspect("equal")
pt.plot(x[0], x[1], "x")
tree.plot(fill=None)
pt.show()
def test_generate_numbered_unique_names():
from pytools import generate_numbered_unique_names
gen = generate_numbered_unique_names("a")
assert next(gen) == (0, "a")
assert next(gen) == (1, "a_0")
gen = generate_numbered_unique_names("b", 6)
assert next(gen) == (7, "b_6")
def test_cartesian_product():
from pytools import cartesian_product
expected_outputs = [
(0, 2, 4),
(0, 2, 5),
(0, 3, 4),
(0, 3, 5),
(1, 2, 4),
(1, 2, 5),
(1, 3, 4),
(1, 3, 5),
]
for i, output in enumerate(cartesian_product([0, 1], [2, 3], [4, 5])):
assert output == expected_outputs[i]
def test_find_module_git_revision():
import pytools
print(pytools.find_module_git_revision(pytools.__file__, n_levels_up=1))
def test_reshaped_view():
import pytools
np = pytest.importorskip("numpy")
a = np.zeros((10, 2))
b = a.T
c = pytools.reshaped_view(a, -1)
assert c.shape == (20,)
with pytest.raises(AttributeError):
pytools.reshaped_view(b, -1)
def test_processlogger():
logging.basicConfig(level=logging.INFO)
from pytools import ProcessLogger
plog = ProcessLogger(logger, "testing the process logger",
long_threshold_seconds=0.01)
from time import sleep
with plog:
sleep(0.3)
def test_table():
import math
from pytools import Table
tbl = Table()
tbl.add_row(("i", "i^2", "i^3", "sqrt(i)"))
for i in range(8):
tbl.add_row((i, i ** 2, i ** 3, math.sqrt(i)))
print(tbl)
print()
print(tbl.latex())
# {{{ test merging
from pytools import merge_tables
tbl = merge_tables(tbl, tbl, tbl, skip_columns=(0,))
print(tbl.github_markdown())
# }}}
def test_eoc():
np = pytest.importorskip("numpy")
from pytools.convergence import EOCRecorder
eoc = EOCRecorder()
# {{{ test pretty_print
for i in range(1, 8):
eoc.add_data_point(1.0 / i, 10 ** (-i))
p = eoc.pretty_print()
print(p)
print()
p = eoc.pretty_print(
abscissa_format="%.5e",
error_format="%.5e",
eoc_format="%5.2f")
print(p)
# }}}
# {{{ test merging
from pytools.convergence import stringify_eocs
p = stringify_eocs(eoc, eoc, eoc, names=("First", "Second", "Third"))
print(p)
# }}}
# {{{ test invalid inputs
eoc = EOCRecorder()
# scalar inputs are fine
eoc.add_data_point(1, 1)
eoc.add_data_point(1.0, 1.0)
eoc.add_data_point(np.float32(1.0), 1.0)
eoc.add_data_point(np.array(3), 1.0)
eoc.add_data_point(1.0, np.array(3))
# non-scalar inputs are not fine though
with pytest.raises(TypeError):
eoc.add_data_point(np.array([3]), 1.0)
with pytest.raises(TypeError):
eoc.add_data_point(1.0, np.array([3]))
# }}}
def test_natsorted():
from pytools import natorder, natsorted
assert natorder("1.001") < natorder("1.01")
assert natsorted(["x10", "x1", "x9"]) == ["x1", "x9", "x10"]
assert natsorted(map(str, range(100))) == list(map(str, range(100)))
assert natsorted(["x10", "x1", "x9"], reverse=True) == ["x10", "x9", "x1"]
assert natsorted([10, 1, 9], key=lambda d: f"x{d}") == [1, 9, 10]
# {{{ object array iteration behavior
class FakeArray:
nopes = 0
def __len__(self):
FakeArray.nopes += 1
return 10
def __getitem__(self, idx):
FakeArray.nopes += 1
if idx > 10:
raise IndexError
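# FakeArray bumps `nopes` whenever it is probed as a sequence (len() or
# indexing). The test below asserts that make_obj_array wraps the object
# into a 1-entry object array without any such probing.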
def test_make_obj_array_iteration():
pytest.importorskip("numpy")
from pytools.obj_array import make_obj_array
make_obj_array([FakeArray()])
assert FakeArray.nopes == 0, FakeArray.nopes
# }}}
# {{{ test obj array vectorization and decorators
def test_obj_array_vectorize(c=1):
np = pytest.importorskip("numpy")
la = pytest.importorskip("numpy.linalg")
# {{{ functions
import pytools.obj_array as obj
def add_one(ary):
assert ary.dtype.char != "O"
return ary + c
def two_add_one(x, y):
assert x.dtype.char != "O" and y.dtype.char != "O"
return x * y + c
@obj.obj_array_vectorized
def vectorized_add_one(ary):
assert ary.dtype.char != "O"
return ary + c
@obj.obj_array_vectorized_n_args
def vectorized_two_add_one(x, y):
assert x.dtype.char != "O" and y.dtype.char != "O"
return x * y + c
class Adder:
def __init__(self, c):
self.c = c
def add(self, ary):
assert ary.dtype.char != "O"
return ary + self.c
@obj.obj_array_vectorized_n_args
def vectorized_add(self, ary):
assert ary.dtype.char != "O"
return ary + self.c
adder = Adder(c)
# }}}
# {{{ check
scalar_ary = np.ones(42, dtype=np.float64)
object_ary = obj.make_obj_array([scalar_ary, scalar_ary, scalar_ary])
for func, vectorizer, nargs in [
(add_one, obj.obj_array_vectorize, 1),
(two_add_one, obj.obj_array_vectorize_n_args, 2),
(adder.add, obj.obj_array_vectorize, 1),
]:
input_ary = [scalar_ary] * nargs
result = vectorizer(func, *input_ary)
error = la.norm(result - c - 1)
print(error)
input_ary = [object_ary] * nargs
result = vectorizer(func, *input_ary)
error = 0
for func, nargs in [
(vectorized_add_one, 1),
(vectorized_two_add_one, 2),
(adder.vectorized_add, 1),
]:
input_ary = [scalar_ary] * nargs
result = func(*input_ary)
input_ary = [object_ary] * nargs
result = func(*input_ary)
# }}}
# }}}
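# obj_array_vectorize applies the given function entrywise when handed an
# object array and falls back to a plain call for ordinary arrays, which
# is why the scalar_ary and object_ary inputs above share one code path.
# The *_vectorized decorators wrap a function with the same dispatch.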
def test_tag() -> None:
from pytools.tag import (
NonUniqueTagError,
Tag,
Taggable,
UniqueTag,
check_tag_uniqueness,
)
# Need a subclass that defines the copy function in order to test.
@tag_dataclass
class TaggableWithCopy(Taggable):
tags: frozenset[Tag]
def _with_new_tags(self, tags):
return TaggableWithCopy(tags)
class FairRibbon(Tag):
pass
class BlueRibbon(FairRibbon):
pass
class RedRibbon(FairRibbon):
pass
class ShowRibbon(FairRibbon, UniqueTag):
pass
class BestInShowRibbon(ShowRibbon):
pass
class ReserveBestInShowRibbon(ShowRibbon):
pass
class BestInClassRibbon(FairRibbon, UniqueTag):
pass
best_in_show_ribbon = BestInShowRibbon()
reserve_best_in_show_ribbon = ReserveBestInShowRibbon()
blue_ribbon = BlueRibbon()
red_ribbon = RedRibbon()
best_in_class_ribbon = BestInClassRibbon()
# Test that input processing fails if there are multiple instances
# of the same UniqueTag subclass
with pytest.raises(NonUniqueTagError):
check_tag_uniqueness(frozenset((
best_in_show_ribbon,
reserve_best_in_show_ribbon, blue_ribbon, red_ribbon)))
# Test that input processing fails if any of the tags are not
# a subclass of Tag
with pytest.raises(TypeError):
check_tag_uniqueness(frozenset((
"I am not a tag", best_in_show_ribbon, # type: ignore[arg-type]
blue_ribbon, red_ribbon)))
# Test that instantiation succeeds if there are multiple instances
# Tag subclasses.
t1 = TaggableWithCopy(frozenset([reserve_best_in_show_ribbon, blue_ribbon,
red_ribbon]))
assert t1.tags == frozenset((reserve_best_in_show_ribbon, red_ribbon,
blue_ribbon))
# Test that instantiation succeeds if there are multiple instances
# of UniqueTag of different subclasses.
t1 = TaggableWithCopy(frozenset([reserve_best_in_show_ribbon,
best_in_class_ribbon, blue_ribbon,
blue_ribbon]))
assert t1.tags == frozenset((reserve_best_in_show_ribbon, best_in_class_ribbon,
blue_ribbon))
# Test tagged() function
t2 = t1.tagged(red_ribbon)
print(t2.tags)
assert t2.tags == frozenset((reserve_best_in_show_ribbon, best_in_class_ribbon,
blue_ribbon, red_ribbon))
# Test that tagged() fails if a UniqueTag of the same subclass
# is already present
with pytest.raises(NonUniqueTagError):
t1.tagged(best_in_show_ribbon)
# Test that tagged() fails if tags are not a FrozenSet of Tags
with pytest.raises(TypeError):
t1.tagged(tags=frozenset((1,))) # type: ignore[arg-type]
# Test without_tags() function
t4 = t2.without_tags(red_ribbon)
assert t4.tags == t1.tags
# Test that without_tags() fails if the tag is not present.
with pytest.raises(ValueError):
t4.without_tags(red_ribbon)
# Test DottedName comparison
from pytools.tag import DottedName
assert FairRibbon() == FairRibbon()
assert (FairRibbon().tag_name
== FairRibbon().tag_name
== DottedName(("pytools", "test", "test_pytools", "FairRibbon")))
assert FairRibbon() != BlueRibbon()
assert FairRibbon().tag_name != BlueRibbon().tag_name
def test_unordered_hash():
import hashlib
import random
# FIXME: Use randbytes once >=3.9 is OK
lst = [bytes([random.randrange(256) for _ in range(20)])
for _ in range(200)]
lorig = lst[:]
random.shuffle(lst)
from pytools import unordered_hash
assert (unordered_hash(hashlib.sha256(), lorig).digest()
== unordered_hash(hashlib.sha256(), lst).digest())
assert (unordered_hash(hashlib.sha256(), lorig).digest()
== unordered_hash(hashlib.sha256(), lorig).digest())
assert (unordered_hash(hashlib.sha256(), lorig).digest()
!= unordered_hash(hashlib.sha256(), lorig[:-1]).digest())
lst[0] = b"aksdjfla;sdfjafd"
assert (unordered_hash(hashlib.sha256(), lorig).digest()
!= unordered_hash(hashlib.sha256(), lst).digest())
# {{{ sphere sampling
@pytest.mark.parametrize("sampling", [
"equidistant", "fibonacci", "fibonacci_min", "fibonacci_avg",
])
def test_sphere_sampling(sampling, visualize=False):
from functools import partial
from pytools import sphere_sample_equidistant, sphere_sample_fibonacci
npoints = 128
radius = 1.5
if sampling == "equidistant":
sampling_func = partial(sphere_sample_equidistant, r=radius)
elif sampling == "fibonacci":
sampling_func = partial(
sphere_sample_fibonacci, r=radius, optimize=None)
elif sampling == "fibonacci_min":
sampling_func = partial(
sphere_sample_fibonacci, r=radius, optimize="minimum")
elif sampling == "fibonacci_avg":
sampling_func = partial(
sphere_sample_fibonacci, r=radius, optimize="average")
else:
raise ValueError(f"unknown sampling method: '{sampling}'")
np = pytest.importorskip("numpy")
points = sampling_func(npoints)
assert np.all(np.linalg.norm(points, axis=0) < radius + 1.0e-15)
if not visualize:
return
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 10), dpi=300)
ax = fig.add_subplot(111, projection="3d")
import matplotlib.tri as mtri
theta = np.arctan2(np.sqrt(points[0]**2 + points[1]**2), points[2])
phi = np.arctan2(points[1], points[0])
triangles = mtri.Triangulation(theta, phi)
ax.plot_trisurf(points[0], points[1], points[2], triangles=triangles.triangles)
ax.set_xlim([-radius, radius])
ax.set_ylim([-radius, radius])
ax.set_zlim([-radius, radius])
ax.margins(0.05, 0.05, 0.05)
# plt.show()
fig.savefig(f"sphere_sampling_{sampling}")
plt.close(fig)
# }}}
def test_unique_name_gen_conflicting_ok():
from pytools import UniqueNameGenerator
ung = UniqueNameGenerator()
ung.add_names({"a", "b", "c"})
with pytest.raises(ValueError):
ung.add_names({"a"})
ung.add_names({"a", "b", "c"}, conflicting_ok=True)
def test_strtobool():
from pytools import strtobool
assert strtobool("true") is True
assert strtobool("tRuE") is True
assert strtobool("1") is True
assert strtobool("t") is True
assert strtobool("on") is True
assert strtobool("false") is False
assert strtobool("FaLse") is False
assert strtobool("0") is False
assert strtobool("f") is False
assert strtobool("off") is False
    # Check each invalid value in its own raises-block: statements after the
    # first raise inside a single block would never execute.
    for value in ["tru", "fal", "xxx", "."]:
        with pytest.raises(ValueError):
            strtobool(value)
assert strtobool(None, False) is False
def test_to_identifier() -> None:
from pytools import to_identifier
assert to_identifier("_a_123_") == "_a_123_"
assert to_identifier("a_123") == "a_123"
assert to_identifier("a 123") == "a123"
assert to_identifier("123") == "_123"
assert to_identifier("_123") == "_123"
assert to_identifier("123A") == "_123A"
assert to_identifier("") == "_"
assert not "a 123".isidentifier()
assert to_identifier("a 123").isidentifier()
assert to_identifier("123").isidentifier()
assert to_identifier("").isidentifier()
def test_typedump():
from pytools import typedump
assert typedump("") == "str"
assert typedump("abcdefg") == "str"
assert typedump(5) == "int"
assert typedump((5.0, 4)) == "tuple(float,int)"
assert typedump([5, 4]) == "list(int,int)"
assert typedump({5, 4}) == "set(int,int)"
assert typedump(frozenset((1, 2, 3))) == "frozenset(int,int,int)"
assert typedump([5, 4, 3, 2, 1]) == "list(int,int,int,int,int)"
assert typedump([5, 4, 3, 2, 1, 0]) == "list(int,int,int,int,int,...)"
assert typedump([5, 4, 3, 2, 1, 0], max_seq=6) == "list(int,int,int,int,int,int)"
assert typedump({5: 42, 7: 43}) == "{'5': int, '7': int}"
class C:
class D:
pass
assert typedump(C()) == "pytools.test.test_pytools.test_typedump.<locals>.C"
assert typedump(C.D()) == "pytools.test.test_pytools.test_typedump.<locals>.C.D"
assert typedump(C.D(), fully_qualified_name=False) == "D"
from pytools.datatable import DataTable
t = DataTable(column_names=[])
assert typedump(t) == "pytools.datatable.DataTable()"
assert typedump(t, special_handlers={type(t): lambda x: "foo"}) == "foo"
def test_unique():
from pytools import unique, unique_difference, unique_intersection, unique_union
assert list(unique([1, 2, 1])) == [1, 2]
assert tuple(unique((1, 2, 1))) == (1, 2)
assert list(range(1000)) == list(unique(range(1000)))
assert list(unique(list(range(1000)) + list(range(1000)))) == list(range(1000))
# Also test strings since their ordering would be thrown off by
# set-based 'unique' implementations.
assert list(unique(["a", "b", "a"])) == ["a", "b"]
assert tuple(unique(("a", "b", "a"))) == ("a", "b")
assert list(unique_difference(["a", "b", "c"], ["b", "c", "d"])) == ["a"]
assert list(unique_difference(["a", "b", "c"], ["a", "b", "c", "d"])) == []
assert list(unique_difference(["a", "b", "c"], ["a"], ["b"], ["c"])) == []
assert list(unique_intersection(["a", "b", "a"], ["b", "c", "a"])) == ["a", "b"]
assert list(unique_intersection(["a", "b", "a"], ["d", "c", "e"])) == []
assert list(unique_union(["a", "b", "a"], ["b", "c", "b"])) == ["a", "b", "c"]
assert list(unique_union(
["a", "b", "a"], ["b", "c", "b"], ["c", "d", "c"])) == ["a", "b", "c", "d"]
assert list(unique(["a", "b", "a"])) == \
list(unique_union(["a", "b", "a"])) == ["a", "b"]
assert list(unique_intersection()) == []
assert list(unique_difference()) == []
assert list(unique_union()) == []
# {{{ test_record

# This class must be defined globally to be picklable
class SimpleRecord(Record):
pass
def test_record():
r = SimpleRecord(c=3, b=2, a=1)
assert r.a == 1
assert r.b == 2
assert r.c == 3
# Fields are sorted alphabetically in records
assert str(r) == "SimpleRecord(a=1, b=2, c=3)"
# Unregistered fields are (silently) ignored for printing
r.f = 6
assert str(r) == "SimpleRecord(a=1, b=2, c=3)"
# Registered fields are printed
r.register_fields({"d", "e"})
assert str(r) == "SimpleRecord(a=1, b=2, c=3)"
r.d = 4
r.e = 5
assert str(r) == "SimpleRecord(a=1, b=2, c=3, d=4, e=5)"
with pytest.raises(AttributeError):
r.ff # noqa: B018
# Test pickling
import pickle
r_pickled = pickle.loads(pickle.dumps(r))
assert r == r_pickled
# }}}
# {{{ __slots__, __dict__, __weakref__ handling
class RecordWithEmptySlots(Record):
__slots__ = []
assert hasattr(RecordWithEmptySlots(), "__slots__")
assert not hasattr(RecordWithEmptySlots(), "__dict__")
assert not hasattr(RecordWithEmptySlots(), "__weakref__")
class RecordWithUnsetSlots(Record):
pass
assert hasattr(RecordWithUnsetSlots(), "__slots__")
assert hasattr(RecordWithUnsetSlots(), "__dict__")
assert hasattr(RecordWithUnsetSlots(), "__weakref__")
from pytools import ImmutableRecord
class ImmutableRecordWithEmptySlots(ImmutableRecord):
__slots__ = []
assert hasattr(ImmutableRecordWithEmptySlots(), "__slots__")
assert hasattr(ImmutableRecordWithEmptySlots(), "__dict__")
assert hasattr(ImmutableRecordWithEmptySlots(), "__weakref__")
class ImmutableRecordWithUnsetSlots(ImmutableRecord):
pass
assert hasattr(ImmutableRecordWithUnsetSlots(), "__slots__")
assert hasattr(ImmutableRecordWithUnsetSlots(), "__dict__")
assert hasattr(ImmutableRecordWithUnsetSlots(), "__weakref__")
# }}}
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])
VERSION = (2019, 1)
VERSION_STATUS = ""
VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS
from __future__ import annotations
import re
from importlib import metadata
VERSION_TEXT = metadata.version("pytools")
_match = re.match(r"^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT)
assert _match is not None
VERSION_STATUS = _match.group(2)
VERSION = tuple(int(nr) for nr in _match.group(1).split("."))
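# Worked example (illustrative): for VERSION_TEXT == "2024.1.5" the regex
# yields VERSION == (2024, 1, 5) and VERSION_STATUS == ""; for "2024.1rc1"
# it yields VERSION == (2024, 1) and VERSION_STATUS == "rc1", since the
# numeric group stops at the first character outside [0-9.].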
#! /bin/bash
set -ex
python -m mypy pytools
python -m mypy --strict --follow-imports=silent \
pytools/datatable.py \
pytools/graph.py \
pytools/persistent_dict.py \
pytools/prefork.py \
        pytools/tag.py
#!/bin/bash
set -o errexit -o nounset
ci_support="https://gitlab.tiker.net/inducer/ci-support/raw/main"
if [[ ! -f .pylintrc.yml ]]; then
curl -o .pylintrc.yml "${ci_support}/.pylintrc-default.yml"
fi
if [[ ! -f .run-pylint.py ]]; then
curl -L -o .run-pylint.py "${ci_support}/run-pylint.py"
fi
PYLINT_RUNNER_ARGS="--jobs=4 --yaml-rcfile=.pylintrc.yml"
if [[ -f .pylintrc-local.yml ]]; then
PYLINT_RUNNER_ARGS+=" --yaml-rcfile=.pylintrc-local.yml"
fi
PYTHONWARNINGS=ignore python .run-pylint.py $PYLINT_RUNNER_ARGS pytools examples "$@"
[flake8]
ignore = E126,E127,E128,E123,E226,E241,E242,E265,E402,W503,E731
max-line-length=85
exclude=pytools/arithmetic_container.py,pytools/decorator.py
[wheel]
universal = 1
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from setuptools import setup
ver_dic = {}
version_file = open("pytools/version.py")
try:
version_file_contents = version_file.read()
finally:
version_file.close()
exec(compile(version_file_contents, "pytools/version.py", 'exec'), ver_dic)
setup(name="pytools",
version=ver_dic["VERSION_TEXT"],
description="A collection of tools for Python",
long_description=open("README.rst", "r").read(),
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Other Audience',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
          'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Information Analysis',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Visualization',
'Topic :: Software Development :: Libraries',
'Topic :: Utilities',
],
install_requires=[
"decorator>=3.2.0",
"appdirs>=1.4.0",
"six>=1.8.0",
"numpy>=1.6.0",
],
author="Andreas Kloeckner",
url="http://pypi.python.org/pypi/pytools",
author_email="inform@tiker.net",
license="MIT",
packages=["pytools"])
from __future__ import absolute_import, division, with_statement
import shutil
import sys # noqa
import tempfile
import pytest
from six.moves import range, zip
from pytools.persistent_dict import (CollisionWarning, NoSuchEntryError,
PersistentDict, ReadOnlyEntryError, WriteOncePersistentDict)
# {{{ type for testing
class PDictTestingKeyOrValue(object):
def __init__(self, val, hash_key=None):
self.val = val
if hash_key is None:
hash_key = val
self.hash_key = hash_key
def __getstate__(self):
return {"val": self.val, "hash_key": self.hash_key}
def __eq__(self, other):
return self.val == other.val
def __ne__(self, other):
return not self.__eq__(other)
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, self.hash_key)
def __repr__(self):
return "PDictTestingKeyOrValue(val=%r,hash_key=%r)" % (
(self.val, self.hash_key))
__str__ = __repr__
# }}}
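# PDictTestingKeyOrValue deliberately decouples equality (based on val)
# from persistent hashing (based on hash_key): update_persistent_hash
# feeds only hash_key to the key builder. The collision tests below use
# this to give two unequal keys identical hashes.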
def test_persistent_dict_storage_and_lookup():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir)
from random import randrange
def rand_str(n=20):
return "".join(
chr(65+randrange(26))
for i in range(n))
keys = [(randrange(2000), rand_str(), None) for i in range(20)]
values = [randrange(2000) for i in range(20)]
d = dict(list(zip(keys, values)))
# {{{ check lookup
for k, v in zip(keys, values):
pdict[k] = v
for k, v in d.items():
assert d[k] == pdict[k]
# }}}
# {{{ check updating
for k, v in zip(keys, values):
pdict[k] = v + 1
for k, v in d.items():
assert d[k] + 1 == pdict[k]
# }}}
# {{{ check store_if_not_present
for k, v in zip(keys, values):
pdict.store_if_not_present(k, d[k] + 2)
for k, v in d.items():
assert d[k] + 1 == pdict[k]
pdict.store_if_not_present(2001, 2001)
assert pdict[2001] == 2001
# }}}
# check not found
with pytest.raises(NoSuchEntryError):
pdict.fetch(3000)
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_deletion():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir)
pdict[0] = 0
del pdict[0]
with pytest.raises(NoSuchEntryError):
pdict.fetch(0)
with pytest.raises(NoSuchEntryError):
del pdict[1]
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_synchronization():
try:
tmpdir = tempfile.mkdtemp()
pdict1 = PersistentDict("pytools-test", container_dir=tmpdir)
pdict2 = PersistentDict("pytools-test", container_dir=tmpdir)
# check lookup
pdict1[0] = 1
assert pdict2[0] == 1
# check updating
pdict1[0] = 2
assert pdict2[0] == 2
# check deletion
del pdict1[0]
with pytest.raises(NoSuchEntryError):
pdict2.fetch(0)
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_cache_collisions():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir)
key1 = PDictTestingKeyOrValue(1, hash_key=0)
key2 = PDictTestingKeyOrValue(2, hash_key=0)
pdict[key1] = 1
# check lookup
with pytest.warns(CollisionWarning):
with pytest.raises(NoSuchEntryError):
pdict.fetch(key2)
# check deletion
with pytest.warns(CollisionWarning):
with pytest.raises(NoSuchEntryError):
del pdict[key2]
# check presence after deletion
assert pdict[key1] == 1
# check store_if_not_present
pdict.store_if_not_present(key2, 2)
assert pdict[key1] == 1
finally:
shutil.rmtree(tmpdir)
def test_persistent_dict_clear():
try:
tmpdir = tempfile.mkdtemp()
pdict = PersistentDict("pytools-test", container_dir=tmpdir)
pdict[0] = 1
pdict.fetch(0)
pdict.clear()
with pytest.raises(NoSuchEntryError):
pdict.fetch(0)
finally:
shutil.rmtree(tmpdir)
@pytest.mark.parametrize("in_mem_cache_size", (0, 256))
def test_write_once_persistent_dict_storage_and_lookup(in_mem_cache_size):
try:
tmpdir = tempfile.mkdtemp()
pdict = WriteOncePersistentDict(
"pytools-test", container_dir=tmpdir,
in_mem_cache_size=in_mem_cache_size)
# check lookup
pdict[0] = 1
assert pdict[0] == 1
# do two lookups to test the cache
assert pdict[0] == 1
# check updating
with pytest.raises(ReadOnlyEntryError):
pdict[0] = 2
# check not found
with pytest.raises(NoSuchEntryError):
pdict.fetch(1)
# check store_if_not_present
pdict.store_if_not_present(0, 2)
assert pdict[0] == 1
pdict.store_if_not_present(1, 1)
assert pdict[1] == 1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_lru_policy():
try:
tmpdir = tempfile.mkdtemp()
pdict = WriteOncePersistentDict(
"pytools-test", container_dir=tmpdir, in_mem_cache_size=3)
pdict[1] = PDictTestingKeyOrValue(1)
pdict[2] = PDictTestingKeyOrValue(2)
pdict[3] = PDictTestingKeyOrValue(3)
pdict[4] = PDictTestingKeyOrValue(4)
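        # The in-memory cache holds three entries, so a fetched value stays
        # resident (identity-preserved) only while its key is among the
        # three most recently fetched. The final fetch of 4 below pushes
        # key 1 out as the least recently used entry.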
val1 = pdict.fetch(1)
assert pdict.fetch(1) is val1
pdict.fetch(2)
assert pdict.fetch(1) is val1
pdict.fetch(2)
pdict.fetch(3)
assert pdict.fetch(1) is val1
pdict.fetch(2)
pdict.fetch(3)
pdict.fetch(2)
assert pdict.fetch(1) is val1
pdict.fetch(2)
pdict.fetch(3)
pdict.fetch(4)
assert pdict.fetch(1) is not val1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_synchronization():
try:
tmpdir = tempfile.mkdtemp()
pdict1 = WriteOncePersistentDict("pytools-test", container_dir=tmpdir)
pdict2 = WriteOncePersistentDict("pytools-test", container_dir=tmpdir)
# check lookup
pdict1[1] = 0
assert pdict2[1] == 0
# check updating
with pytest.raises(ReadOnlyEntryError):
pdict2[1] = 1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_cache_collisions():
try:
tmpdir = tempfile.mkdtemp()
pdict = WriteOncePersistentDict("pytools-test", container_dir=tmpdir)
key1 = PDictTestingKeyOrValue(1, hash_key=0)
key2 = PDictTestingKeyOrValue(2, hash_key=0)
pdict[key1] = 1
# check lookup
with pytest.warns(CollisionWarning):
with pytest.raises(NoSuchEntryError):
pdict.fetch(key2)
# check update
with pytest.raises(ReadOnlyEntryError):
pdict[key2] = 1
# check store_if_not_present
pdict.store_if_not_present(key2, 2)
assert pdict[key1] == 1
finally:
shutil.rmtree(tmpdir)
def test_write_once_persistent_dict_clear():
try:
tmpdir = tempfile.mkdtemp()
pdict = WriteOncePersistentDict("pytools-test", container_dir=tmpdir)
pdict[0] = 1
pdict.fetch(0)
pdict.clear()
with pytest.raises(NoSuchEntryError):
pdict.fetch(0)
finally:
shutil.rmtree(tmpdir)
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
pytest.main([__file__])