Skip to content
test_tree.py 36.1 KiB
Newer Older
Matt Wala's avatar
Matt Wala committed
    nballs = 10**4
    ball_centers = make_normal_particle_array(queue, nballs, dims, dtype)
    ball_radii = cl.array.empty(queue, nballs, dtype).fill(0.1)

    from boxtree.area_query import (
        LeavesToBallsLookupBuilder, SpaceInvaderQueryBuilder)

    siqb = SpaceInvaderQueryBuilder(ctx)
    # We can use leaves-to-balls lookup to get the set of overlapping balls for
    # each box, and from there to compute the outer space invader distance.
    lblb = LeavesToBallsLookupBuilder(ctx)

    siq, _ = siqb(queue, tree, ball_centers, ball_radii)
    lbl, _ = lblb(queue, tree, ball_centers, ball_radii)

    # get data to host for test
    tree = tree.get(queue=queue)
    siq = siq.get(queue=queue)
    lbl = lbl.get(queue=queue)

    ball_centers = np.array([x.get() for x in ball_centers])
    ball_radii = ball_radii.get()

    # Find leaf boxes.
    from boxtree import box_flags_enum

    outer_space_invader_dist = np.zeros(tree.nboxes)

    for ibox in range(tree.nboxes):
        # We only want leaves here.
        if tree.box_flags[ibox] & box_flags_enum.HAS_CHILDREN:
            continue

        start, end = lbl.balls_near_box_starts[ibox:ibox + 2]
        space_invaders = lbl.balls_near_box_lists[start:end]
        if len(space_invaders) > 0:
            outer_space_invader_dist[ibox] = np.max(np.abs(
                    tree.box_centers[:, ibox].reshape((-1, 1))
                    - ball_centers[:, space_invaders]))

    assert np.allclose(siq, outer_space_invader_dist)

# }}}


# {{{ test_same_tree_with_zero_weight_particles

@pytest.mark.parametrize("dims", [2, 3])
def test_same_tree_with_zero_weight_particles(ctx_factory, dims):
    """Build trees whose targets all carry zero refine weight.

    For every stick-out factor, sources (refine weight 1) come from a fixed
    seed, while targets (refine weight 0, with random radii) are drawn fresh.
    Each tree is built with ``debug=True``, copied to the host and collected;
    its box count is printed. No assertion is made on the resulting trees --
    this is a smoke test (flip the ``if 0:`` switch below to plot them).
    """
    logging.basicConfig(level=logging.INFO)

    # NOTE(review): the loop below only exercises ntargets=40; this list is
    # used solely to size the pool of random targets. TODO confirm whether
    # the loop was meant to iterate over these values.
    ntargets_values = [300, 400, 500]
    stick_out_factors = [0, 0.1, 0.3, 1]
    nsources = 20

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    from boxtree import TreeBuilder
    tb = TreeBuilder(ctx)

    trees = []

    for stick_out_factor in stick_out_factors:
        for ntargets in [40]:
            # Use private generators instead of np.random.seed(...):
            # reseeding the global NumPy RNG would leak state into every
            # later test. RandomState(10) yields the same stream that
            # np.random.seed(10) did, so the sources are unchanged.
            source_rng = np.random.RandomState(10)
            sources = source_rng.rand(dims, nsources)**2
            # Squared uniforms lie in [0, 1); move two sources just outside
            # that range (presumably to pin the tree's bounding box).
            sources[:, 0] = -0.1
            sources[:, 1] = 1.1

            # Targets are intentionally unseeded (fresh entropy each time).
            # They are drawn from a pool sized by the largest target count,
            # of which the first ntargets are kept.
            target_rng = np.random.RandomState()
            max_ntargets = max(ntargets_values)
            targets = target_rng.rand(dims, max_ntargets)[:, :ntargets].copy()
            target_radii = target_rng.rand(max_ntargets)[:ntargets]

            sources = cl.array.to_device(queue, sources)
            targets = cl.array.to_device(queue, targets)

            # Weights are laid out sources first, then targets (presumably
            # matching TreeBuilder's particle ordering): sources weigh 1,
            # targets weigh 0.
            refine_weights = cl.array.empty(queue, nsources + ntargets, np.int32)
            refine_weights[:nsources] = 1
            refine_weights[nsources:] = 0

            tree, _ = tb(queue, sources, targets=targets,
                    target_radii=target_radii,
                    stick_out_factor=stick_out_factor,
                    max_leaf_refine_weight=10,
                    refine_weights=refine_weights,
                    debug=True)
            tree = tree.get(queue=queue)
            trees.append(tree)

            print("TREE:", tree.nboxes)

    if 0:
        import matplotlib.pyplot as plt
        for tree in trees:
            plt.figure()
            tree.plot()

        plt.show()

# }}}

# {{{ test_max_levels_error

def test_max_levels_error(ctx_factory):
    """Coincident particles must make the tree builder exceed its level limit.

    Eleven particles all sitting at the origin cannot be separated by box
    subdivision, so with at most 10 particles allowed per box the builder is
    expected to raise :class:`MaxLevelsExceeded`.
    """
    from boxtree import TreeBuilder
    from boxtree.tree_build import MaxLevelsExceeded

    context = ctx_factory()
    cq = cl.CommandQueue(context)
    builder = TreeBuilder(context)

    logging.basicConfig(level=logging.INFO)

    # Two all-zero coordinate arrays (presumably one per axis of a 2D
    # problem): every particle coincides with every other.
    coincident = [cl.array.zeros(cq, 11, float) for _ in range(2)]

    with pytest.raises(MaxLevelsExceeded):
        _tree, _ = builder(cq, coincident, max_particles_in_box=10, debug=True)

# }}}


# You can test individual routines by typing
# $ python test_tree.py 'test_routine(cl.create_some_context)'

if __name__ == "__main__":
    if len(sys.argv) > 1:
        # Developer convenience: evaluate the given statement, e.g. to run a
        # single test with an explicit context factory. This executes
        # arbitrary code -- never feed it untrusted input.
        exec(sys.argv[1])
    else:
        from pytest import main
        # Propagate pytest's exit status so failures are visible to the shell.
        sys.exit(main([__file__]))

# vim: fdm=marker