# test_pytools.py (4.35 KiB) -- text recovered from a repository web view.
# Web-view blame annotation: Andreas Klöckner committed.
from __future__ import absolute_import, division, with_statement
# Web-view blame annotation: Andreas Klöckner committed.
import pytest

@pytest.mark.skipif("sys.version_info < (2, 5)")
def test_memoize_method_clear():
    """Check that ``memoize_method`` caches and that ``clear_cache`` resets.

    The memoized method bumps ``run_count`` every time its body actually
    executes, so the counter tells cache hits apart from real evaluations.
    """
    from pytools import memoize_method

    class SomeClass:
        def __init__(self):
            self.run_count = 0

        @memoize_method
        def f(self):
            self.run_count += 1
            return 17

    sc = SomeClass()
    sc.f()
    # The second call must be served from the cache, i.e. not run the body.
    sc.f()
    assert sc.run_count == 1

    sc.f.clear_cache(sc)  # pylint: disable=no-member

    # With the cache dropped, the next call has to evaluate the body again.
    sc.f()
    assert sc.run_count == 2
def test_memoize_method_with_uncached():
    """Check ``memoize_method_with_uncached`` key handling and cache clearing.

    ``f`` is declared with ``uncached_args=[1]`` and ``uncached_kwargs=["z"]``;
    the assertions below show that calls differing only in ``y``/``z`` share
    one cache entry, while a different ``x`` triggers a new evaluation.
    """
    from pytools import memoize_method_with_uncached

    class SomeClass:
        def __init__(self):
            self.run_count = 0

        @memoize_method_with_uncached(uncached_args=[1], uncached_kwargs=["z"])
        def f(self, x, y, z):
            # The arguments only matter for the cache key, not for the result.
            del x, y, z
            self.run_count += 1
            return 17

    sc = SomeClass()
    sc.f(17, 18, z=19)
    # Only y and z changed: expected to be a cache hit.
    sc.f(17, 19, z=20)
    assert sc.run_count == 1
    # x changed: expected to be a cache miss.
    sc.f(18, 19, z=20)
    assert sc.run_count == 2

    sc.f.clear_cache(sc)  # pylint: disable=no-member

    # After clearing, a previously cached call must be evaluated again.
    sc.f(17, 18, z=19)
    assert sc.run_count == 3
def test_memoize_method_nested():
    """Check that a function memoized inside a method is evaluated only once.

    ``inner`` increments the owning object's ``run_count`` whenever its body
    runs; two calls with the same argument must leave the counter at one.
    """
    from pytools import memoize_method_nested

    class SomeClass:
        def __init__(self):
            self.run_count = 0

        def f(self):
            # The decorator is applied directly inside the method, with the
            # instance bound to the local name ``self``, exactly as before.
            @memoize_method_nested
            def inner(x):
                self.run_count += 1
                return 2*x

            # Same argument twice: the second call should hit the cache.
            for _ in range(2):
                inner(5)

    instance = SomeClass()
    instance.f()
    assert instance.run_count == 1


def test_p_convergence_verifier():
    """Exercise ``PConvergenceVerifier`` on decaying and constant errors.

    Errors of ``0.1**order`` and ``0.5**order`` must pass verification, while
    a constant error of ``2`` must make the verifier raise ``AssertionError``.
    """
    pytest.importorskip("numpy")

    from pytools.convergence import PConvergenceVerifier

    orders = [2, 3, 4, 5]

    def make_verifier(error_for):
        # Build a fresh verifier fed with one data point per order.
        verifier = PConvergenceVerifier()
        for p in orders:
            verifier.add_data_point(p, error_for(p))
        return verifier

    make_verifier(lambda p: 0.1**p)()

    make_verifier(lambda p: 0.5**p)()

    stagnating = make_verifier(lambda p: 2)
    with pytest.raises(AssertionError):
        stagnating()


def test_memoize():
    """Check that ``memoize`` caches results for positional and keyword calls.

    ``count[0]`` records how often the body of ``f`` really runs: three
    distinct calls give three evaluations, and repeating the same three calls
    must not add any more.
    """
    from pytools import memoize
    count = [0]

    # Without the decorator nothing is cached and the final count would be 6.
    # ``use_kwargs=True`` is needed because ``f`` is also called with ``j=``.
    @memoize(use_kwargs=True)
    def f(i, j=1):
        count[0] += 1
        return i + j

    assert f(1) == 2
    assert f(1, 2) == 3
    assert f(2, j=3) == 5
    assert count[0] == 3
    # Identical calls again: all results must come from the cache.
    assert f(1) == 2
    assert f(1, 2) == 3
    assert f(2, j=3) == 5
    assert count[0] == 3


def test_memoize_keyfunc():
    """Check ``memoize`` with a user-supplied ``key`` function.

    The key is ``(i, len(j))``, so calls whose ``j`` differs only in content
    or container type (but not in length) must share one cache entry.
    """
    from pytools import memoize
    n_evals = [0]

    def cache_key(i, j=(1,)):
        return (i, len(j))

    @memoize(key=cache_key)
    def f(i, j=(1,)):
        n_evals[0] += 1
        return i + len(j)

    # Keys (1, 1), (1, 1) and (2, 2): two distinct entries, two evaluations.
    assert f(1) == 2
    assert f(1, [2]) == 2
    assert f(2, j=[2, 3]) == 4
    assert n_evals[0] == 2

    # Tuples of the same lengths map to the same keys: no new evaluations.
    assert f(1) == 2
    assert f(1, (2,)) == 2
    assert f(2, j=(2, 3)) == 4
    assert n_evals[0] == 2


@pytest.mark.parametrize("dims", [2, 3])
def test_spatial_btree(dims, do_plot=False):
    """Insert 2000 random points into a ``SpatialBinaryTreeBucket``.

    :arg dims: number of spatial dimensions (parametrized as 2 and 3).
    :arg do_plot: if true, show the points and the tree using matplotlib.
    """
    pytest.importorskip("numpy")
    import numpy as np

    nparticles = 2000

    # Random samples, warped by a signed power and wrapped back with a modulo.
    samples = -1 + 2*np.random.rand(dims, nparticles)
    samples = np.sign(samples)*np.abs(samples)**1.9
    points = (1.4 + samples) % 2 - 1

    # Per-dimension extent of the point cloud, used as the tree's bounds.
    bl = np.min(points, axis=-1)
    tr = np.max(points, axis=-1)
    print(bl, tr)

    from pytools.spatial_btree import SpatialBinaryTreeBucket
    tree = SpatialBinaryTreeBucket(bl, tr, max_elements_per_box=10)
    for idx in range(nparticles):
        point = points[:, idx]
        # The same coordinates are passed for both entries of the pair
        # (presumably a degenerate bounding box -- confirm against insert()).
        tree.insert(idx, (point, point))

    if not do_plot:
        return

    import matplotlib.pyplot as pt
    pt.gca().set_aspect("equal")
    pt.plot(points[0], points[1], "x")
    tree.plot(fill=None)
    pt.show()


def test_generate_numbered_unique_names():
    """Check the first items produced by ``generate_numbered_unique_names``.

    Without a second argument the generator yields ``(0, "a")`` and then
    ``(1, "a_0")``; called with ``6`` its first item is ``(7, "b_6")``.
    """
    from pytools import generate_numbered_unique_names

    plain = generate_numbered_unique_names("a")
    first = next(plain)
    assert first == (0, "a")
    second = next(plain)
    assert second == (1, "a_0")

    numbered = generate_numbered_unique_names("b", 6)
    first_numbered = next(numbered)
    assert first_numbered == (7, "b_6")


def test_find_module_git_revision():
    """Smoke test: call ``find_module_git_revision`` on pytools itself.

    The result is only printed, not checked; the test passes as long as the
    call does not raise.
    """
    import pytools

    revision = pytools.find_module_git_revision(
            pytools.__file__, n_levels_up=1)
    print(revision)


# Web-view blame annotation: Hao Gao committed.
def test_reshaped_view():
    """Check ``pytools.reshaped_view`` on a plain and a transposed array.

    Flattening a freshly created ``(10, 2)`` array must give shape ``(20,)``,
    while flattening its transpose is expected to raise ``AttributeError``.
    """
    import pytools
    # Skip (rather than error out) when numpy is unavailable, matching the
    # other numpy-dependent tests in this file.
    np = pytest.importorskip("numpy")

    a = np.zeros((10, 2))
    b = a.T
    c = pytools.reshaped_view(a, -1)
    assert c.shape == (20,)
    with pytest.raises(AttributeError):
        pytools.reshaped_view(b, -1)


if __name__ == "__main__":
    # Imported here because no module-level ``import sys`` is visible in this
    # file; without it the ``sys.argv`` accesses below raise NameError.
    import sys

    if len(sys.argv) > 1:
        # NOTE(review): runs the first command-line argument as Python code,
        # e.g. ``python test_pytools.py "test_memoize()"``. Developer
        # convenience only -- never feed it untrusted input.
        exec(sys.argv[1])
    else:
        # No argument given: run every test in this file through pytest.
        from pytest import main
        main([__file__])