def main()

in benchmark/launch_benchmark.py [0:0]


def main():
    """Command-line entry point for the all-reduce benchmark.

    Parses the command-line flags, forwards them unchanged to
    ``run_one_machine`` and serializes whatever that call returns with
    ``torch.save`` to ``--output`` (the binary stdout stream by default).

    The output file is opened by argparse while parsing, i.e. before the
    benchmark runs; it is closed here once the result has been written (or
    the run has failed), unless it is one of the process's stdout streams.
    """
    parser = argparse.ArgumentParser(description="all-reduce benchmark")
    parser.add_argument(
        "--init-method",
        type=str,
        default="env://",
        help="How to do rendezvous between machines (uses PyTorch, hence see its doc)",
    )
    parser.add_argument(
        "--machine-idx",
        type=int,
        required=True,
        help="The rank of the machine on which this script was invoked (0-based)",
    )
    parser.add_argument(
        "--num-machines",
        type=int,
        required=True,
        help="On how many machines this script is being invoked (each with its own rank)",
    )
    parser.add_argument(
        "--num-devices-per-machine",
        type=int,
        required=True,
        help="How many clients this script should launch (each will use one GPU)",
    )
    parser.add_argument(
        "--num-buckets",
        type=int,
        required=True,
        help="How many buffers to do an allreduce over in each epoch",
    )
    parser.add_argument(
        "--bucket-size",
        type=int,
        required=True,
        help="How big each buffer should be (expressed in number of float32 elements)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        required=True,
        help="How many times to run the benchmark",
    )
    # The two NCCL tuning flags are optional: when omitted they are passed to
    # run_one_machine as None.
    parser.add_argument(
        "--num-network-threads",
        type=int,
        help="The value of the NCCL_SOCKET_NTHREADS env var (see NCCL's doc)",
    )
    parser.add_argument(
        "--num-sockets-per-network-thread",
        type=int,
        help="The value of the NCCL_NSOCKS_PERTHREAD env var (see NCCL's doc)",
    )
    # Boolean switch, False unless given; forwarded as use_nccl.
    # NOTE(review): what is used when the flag is absent is decided inside
    # run_one_machine -- confirm there before documenting it in --help.
    parser.add_argument(
        "--use-nccl",
        action="store_true",
    )
    # NOTE: the --pid-file option is currently disabled, both here and in the
    # run_one_machine call below.
    # parser.add_argument(
    #     "--pid-file",
    #     type=str,
    # )
    # Forwarded as-is (None when omitted).
    # NOTE(review): its meaning is defined by run_one_machine -- confirm there.
    parser.add_argument(
        "--parallelism",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--output",
        type=argparse.FileType("wb"),
        default=sys.stdout.buffer,
        help="File to write the result to, serialized with torch.save (default: stdout)",
    )

    args = parser.parse_args()

    try:
        res = run_one_machine(
            init_method=args.init_method,
            machine_idx=args.machine_idx,
            num_machines=args.num_machines,
            num_devices_per_machine=args.num_devices_per_machine,
            num_buckets=args.num_buckets,
            bucket_size=args.bucket_size,
            num_epochs=args.num_epochs,
            num_network_threads=args.num_network_threads,
            num_sockets_per_network_thread=args.num_sockets_per_network_thread,
            use_nccl=args.use_nccl,
            parallelism=args.parallelism,
            # pid_file=args.pid_file,
        )

        torch.save(res, args.output)
    finally:
        # argparse.FileType hands back an already-open file that nobody else
        # owns, so close it here to flush the data and release the handle.
        # The stdout streams (the default, or what "-" may resolve to) belong
        # to the interpreter and must stay open.
        if args.output is not sys.stdout and args.output is not sys.stdout.buffer:
            args.output.close()