def timed_broadcast()

in bench_cluster/communication/broadcast.py [0:0]


def _broadcast_from_rank0(tensor, count, async_op):
    """Issue ``count`` broadcasts of ``tensor`` from rank 0, waiting on async handles.

    With ``async_op=True`` ``dist.broadcast`` returns a work handle instead of
    completing synchronously; every handle is waited on here so the collectives
    are ordered before whatever the caller does next (e.g. recording a timing
    event). With ``async_op=False`` the returned values are ignored.
    """
    handles = [dist.broadcast(tensor, 0, async_op=async_op) for _ in range(count)]
    if async_op:
        for handle in handles:
            # Defensive: skip a missing handle rather than crash on ``None.wait()``.
            if handle is not None:
                handle.wait()


def timed_broadcast(input, start_event, end_event, warmups, trials, async_op, bw_unit, raw):
    """Benchmark ``dist.broadcast`` from rank 0 and print one result row via ``print_rank_0``.

    Runs ``warmups`` untimed broadcasts, then times ``trials`` back-to-back
    broadcasts between ``start_event`` and ``end_event`` and reports the
    per-trial average duration, throughput and bus bandwidth.

    Args:
        input: Tensor to broadcast from rank 0; must expose ``element_size()``
            and ``nelement()`` (used to compute the message size in bytes).
        start_event, end_event: Timing events exposing ``record()`` and
            ``elapsed_time()``. Presumably ``torch.cuda.Event(enable_timing=True)``
            — TODO confirm at the call site. ``elapsed_time`` is assumed to
            return milliseconds, hence the ``/ 1000`` conversion to seconds.
        warmups: Number of untimed broadcasts issued before timing starts.
        trials: Number of timed broadcasts to average over; must be positive.
        async_op: Forwarded to ``dist.broadcast``. When true, the returned work
            handles are waited on before ``end_event`` is recorded so the
            measurement covers the collectives, not just their enqueueing.
        bw_unit: Bandwidth unit forwarded to ``get_bw``.
        raw: Forwarded to ``get_metric_strings``; when false the size column is
            formatted with ``convert_size``, otherwise the raw byte count is printed.

    Raises:
        ValueError: If ``trials`` is not positive (checked before any collective
            is issued, so all ranks fail consistently given identical arguments).
    """
    if trials <= 0:
        raise ValueError(f"trials must be a positive integer, got {trials}")

    sync_all()
    # Warmups, establish connections, etc.
    _broadcast_from_rank0(input, warmups, async_op)
    sync_all()

    # Time the actual comm op `trials` times and average it. Async handles are
    # waited on inside the helper, so end_event is recorded only after the
    # broadcasts are ordered before it.
    start_event.record()
    _broadcast_from_rank0(input, trials, async_op)
    end_event.record()
    sync_all()
    duration = start_event.elapsed_time(end_event) / 1000  # assumed ms -> s

    # Derive and format performance metrics.
    avg_duration = duration / trials
    size_bytes = input.element_size() * input.nelement()
    tput, busbw = get_bw(bw_unit, 'broadcast', size_bytes, avg_duration)
    tput_str, busbw_str, duration_str = get_metric_strings(raw, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    size_field = size_bytes if raw else convert_size(size_bytes)

    print_rank_0(f"{size_field:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")