Skip to content
from __future__ import absolute_import, division, with_statement
import sys
import pytest
@pytest.mark.skipif("sys.version_info < (2, 5)")
def test_memoize_method_clear():
from pytools import memoize_method
class SomeClass:
def __init__(self):
self.run_count = 0
@memoize_method
def f(self):
self.run_count += 1
return 17
sc = SomeClass()
sc.f()
sc.f()
assert sc.run_count == 1
sc.f.clear_cache(sc) # pylint: disable=no-member
def test_memoize_method_with_uncached():
from pytools import memoize_method_with_uncached
class SomeClass:
def __init__(self):
self.run_count = 0
@memoize_method_with_uncached(uncached_args=[1], uncached_kwargs=["z"])
def f(self, x, y, z):
del x, y, z
self.run_count += 1
return 17
sc = SomeClass()
sc.f(17, 18, z=19)
sc.f(17, 19, z=20)
assert sc.run_count == 1
sc.f(18, 19, z=20)
assert sc.run_count == 2
sc.f.clear_cache(sc) # pylint: disable=no-member
def test_memoize_method_nested():
from pytools import memoize_method_nested
class SomeClass:
def __init__(self):
self.run_count = 0
def f(self):
@memoize_method_nested
def inner(x):
self.run_count += 1
return 2*x
inner(5)
inner(5)
sc = SomeClass()
sc.f()
assert sc.run_count == 1
def test_p_convergence_verifier():
pytest.importorskip("numpy")
from pytools.convergence import PConvergenceVerifier
pconv_verifier = PConvergenceVerifier()
for order in [2, 3, 4, 5]:
pconv_verifier.add_data_point(order, 0.1**order)
pconv_verifier()
pconv_verifier = PConvergenceVerifier()
for order in [2, 3, 4, 5]:
pconv_verifier.add_data_point(order, 0.5**order)
pconv_verifier()
pconv_verifier = PConvergenceVerifier()
for order in [2, 3, 4, 5]:
pconv_verifier.add_data_point(order, 2)
with pytest.raises(AssertionError):
pconv_verifier()
def test_memoize():
from pytools import memoize
count = [0]
@memoize(use_kwargs=True)
def f(i, j=1):
count[0] += 1
return i + j
assert f(1) == 2
assert f(1, 2) == 3
assert f(2, j=3) == 5
assert count[0] == 3
assert f(1) == 2
assert f(1, 2) == 3
assert f(2, j=3) == 5
assert count[0] == 3
def test_memoize_keyfunc():
from pytools import memoize
count = [0]
@memoize(key=lambda i, j=(1,): (i, len(j)))
def f(i, j=(1,)):
count[0] += 1
return i + len(j)
assert f(1) == 2
assert f(1, [2]) == 2
assert f(2, j=[2, 3]) == 4
assert count[0] == 2
assert f(1) == 2
assert f(1, (2,)) == 2
assert f(2, j=(2, 3)) == 4
assert count[0] == 2
@pytest.mark.parametrize("dims", [2, 3])
def test_spatial_btree(dims, do_plot=False):
pytest.importorskip("numpy")
import numpy as np
nparticles = 2000
x = -1 + 2*np.random.rand(dims, nparticles)
x = np.sign(x)*np.abs(x)**1.9
x = (1.4 + x) % 2 - 1
bl = np.min(x, axis=-1)
tr = np.max(x, axis=-1)
print(bl, tr)
from pytools.spatial_btree import SpatialBinaryTreeBucket
tree = SpatialBinaryTreeBucket(bl, tr, max_elements_per_box=10)
for i in range(nparticles):
tree.insert(i, (x[:, i], x[:, i]))
if do_plot:
import matplotlib.pyplot as pt
pt.gca().set_aspect("equal")
pt.plot(x[0], x[1], "x")
tree.plot(fill=None)
pt.show()
def test_generate_numbered_unique_names():
from pytools import generate_numbered_unique_names
gen = generate_numbered_unique_names("a")
assert next(gen) == (0, "a")
assert next(gen) == (1, "a_0")
gen = generate_numbered_unique_names("b", 6)
assert next(gen) == (7, "b_6")
def test_find_module_git_revision():
import pytools
print(pytools.find_module_git_revision(pytools.__file__, n_levels_up=1))
def test_reshaped_view():
import pytools
import numpy as np
a = np.zeros((10, 2))
b = a.T
c = pytools.reshaped_view(a, -1)
assert c.shape == (20,)
with pytest.raises(AttributeError):
pytools.reshaped_view(b, -1)
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])