def timed_p2p()

in bench_cluster/communication/p2p.py [0:0]


def _exchange_with_peers(input, rank, world_size, async_op):
    """Run one all-pairs point-to-point round for this rank.

    For every peer ``i != rank`` the lower-ranked side of the pair sends
    ``input`` and the higher-ranked side receives into ``input``. In async
    mode every ``isend``/``irecv`` is posted first and then all returned
    request handles are waited on, so the round's transfers are ordered
    before whatever the caller does next.
    """
    pending = []
    for peer in range(world_size):
        if peer == rank:
            continue
        if async_op:
            if rank < peer:
                req = dist.isend(input, peer)
            else:
                req = dist.irecv(input, src=peer)
            # isend/irecv are documented to return None when this process is
            # not part of the group; only real handles can be waited on.
            if req is not None:
                pending.append(req)
        elif rank < peer:
            dist.send(input, peer)
        else:
            dist.recv(input, src=peer)

    # Wait only after every op of the round has been posted, so matching
    # sends and receives on different ranks cannot block each other.
    for req in pending:
        req.wait()


def timed_p2p(input, start_event, end_event, warmups, trials, async_op, bw_unit, raw):
    """Time an all-pairs send/recv exchange and print one result row.

    Each rank sends ``input`` to every higher rank and receives into
    ``input`` from every lower rank. The exchange is run ``warmups`` times
    untimed, then ``trials`` times between ``start_event`` and ``end_event``.

    Args:
        input: Tensor used as both send and receive buffer; its
            ``element_size()`` and ``nelement()`` give the message size.
        start_event: Event recorded before the timed loop. Must provide
            ``record()`` and ``elapsed_time(other)`` (presumably a
            ``torch.cuda.Event`` with timing enabled -- confirm at call site).
        end_event: Event recorded after the timed loop.
        warmups: Number of untimed rounds.
        trials: Number of timed rounds; the duration is averaged over these.
        async_op: Use ``isend``/``irecv`` (and wait on the handles) instead
            of blocking ``send``/``recv``.
        bw_unit: Passed through to ``get_bw``.
        raw: Passed to ``get_metric_strings``; when false the size column is
            formatted with ``convert_size``.

    Returns:
        None. The result row is emitted through ``print_rank_0``.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    sync_all()
    # Warmups, establish connections, etc.
    for _ in range(warmups):
        _exchange_with_peers(input, rank, world_size, async_op)
    sync_all()

    # time the actual comm op trials times and average it
    start_event.record()
    for _ in range(trials):
        _exchange_with_peers(input, rank, world_size, async_op)
    end_event.record()
    sync_all()
    # Divide by 1000: presumably elapsed_time() is in milliseconds and the
    # downstream helpers expect seconds -- confirm against get_bw.
    duration = start_event.elapsed_time(end_event) / 1000

    # maintain and clean performance data
    avg_duration = duration / trials
    size = input.element_size() * input.nelement()
    n = world_size
    tput, busbw = get_bw(bw_unit, 'p2p', size * (n - 1), avg_duration)  # Multiply size by (n-1) as each process communicates with all others
    tput_str, busbw_str, duration_str = get_metric_strings(raw, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")