Skip to content
Snippets Groups Projects

Test traversal distribution

  • Clone with SSH
  • Clone with HTTPS
  • Embed
  • Share
    The snippet can be accessed without any authentication.
    Authored by Hao Gao
    Edited
    test_tree_distribution.py 2.28 KiB
    from mpi4py import MPI
    import numpy as np
    import pyopencl as cl
    import time
    
    
    # --- MPI ---------------------------------------------------------------
    # All processes share the world communicator; ``rank`` is this process's
    # index within it (rank 0 is the one that builds the tree).
    comm = MPI.COMM_WORLD
    rank = comm.rank

    # --- OpenCL ------------------------------------------------------------
    # One context and one command queue, shared by everything below.
    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)
    
    def _build_tree_on_root(dims, nsources, ntargets, dtype):
        """Generate random particles and build a boxtree tree on this rank.

        Meant to be called on rank 0 only. Prints the time spent generating
        the particles and building the tree.

        :arg dims: spatial dimension of the particles.
        :arg nsources: number of source particles.
        :arg ntargets: number of target particles.
        :arg dtype: coordinate dtype; also used for the target radii.
        :returns: the tree object produced by ``TreeBuilder``.
        """
        start_time = time.perf_counter()

        # Random sources and targets; fixed seeds keep runs reproducible.
        from boxtree.tools import make_normal_particle_array as p_normal
        sources = p_normal(queue, nsources, dims, dtype, seed=15)
        targets = p_normal(queue, ntargets, dims, dtype, seed=18)

        # Draw the radii in the coordinate dtype so they match sources/targets.
        from pyopencl.clrandom import PhiloxGenerator
        rng = PhiloxGenerator(queue.context, seed=22)
        target_radii = rng.uniform(
            queue, ntargets, a=0, b=0.05, dtype=dtype).get()

        # Build the tree and interaction lists.
        from boxtree import TreeBuilder
        tb = TreeBuilder(ctx)
        tree, _ = tb(queue, sources, targets=targets, target_radii=target_radii,
                     stick_out_factor=0.25, max_particles_in_box=30, debug=True)

        # Wait for queued device work so the timing includes it.
        queue.finish()
        print("generate tree object on root takes {} sec.".format(
            time.perf_counter() - start_time), flush=True)

        return tree

    def run_experiment(dims, nsources, ntargets, dtype):
        """Time the distribution of an FMM traversal across all MPI ranks.

        Rank 0 generates random particles and builds the tree. Every rank then
        constructs a ``DistributedFMMInfo`` from it and prints how long that
        step took on that rank.

        :arg dims: spatial dimension of the particles.
        :arg nsources: number of source particles (used on rank 0 only).
        :arg ntargets: number of target particles (used on rank 0 only).
        :arg dtype: coordinate dtype of the generated particles.
        :returns: the ``DistributedFMMInfo`` constructed on this rank.
        """
        # Only rank 0 holds the global tree; other ranks pass ``None``.
        tree = None

        from boxtree.traversal import FMMTraversalBuilder
        tg = FMMTraversalBuilder(ctx, well_sep_is_n_away=2)

        if rank == 0:
            tree = _build_tree_on_root(dims, nsources, ntargets, dtype)

        def fmm_level_to_nterms(tree, level):
            # Number of expansion terms per level: the level, but at least 3.
            return max(level, 3)

        def distributed_expansion_wrangler_factory(tree):
            from boxtree.distributed.calculation import DistributedFMMLibExpansionWrangler

            # NOTE(review): the positional ``0`` is presumably the Helmholtz
            # parameter (0 -> Laplace kernel); confirm against boxtree.
            return DistributedFMMLibExpansionWrangler(
                queue, tree, 0, fmm_level_to_nterms=fmm_level_to_nterms
            )

        # Synchronize so that all ranks start the timer together.
        comm.Barrier()
        start_time = time.perf_counter()

        from boxtree.distributed import DistributedFMMInfo
        distributed_fmm_info = DistributedFMMInfo(
            queue, tree, tg, distributed_expansion_wrangler_factory, comm=comm
        )

        queue.finish()
        print(
            "distribute traversal object on rank {} takes {} sec.".format(
                rank, time.perf_counter() - start_time
            ), flush=True
        )
        # Keep ranks in lockstep so consecutive experiments do not overlap.
        comm.Barrier()

        return distributed_fmm_info
    
    if __name__ == "__main__":
        # Small warm-up problem first, then the actual run at full size.
        for nparticles in (100000, 10000000):
            run_experiment(dims=3, nsources=nparticles, ntargets=nparticles,
                           dtype=np.float64)
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or sign in to comment