in benchmark/launch_benchmark.py [0:0]
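# A minimal example invocation (illustrative only: the hostname, port, and sizes below
# are placeholders, and this assumes the default "env://" rendezvous reads
# MASTER_ADDR/MASTER_PORT from the environment, as torch.distributed does):
#
#   MASTER_ADDR=node0.example.com MASTER_PORT=29500 \
#   python benchmark/launch_benchmark.py \
#       --machine-idx 0 --num-machines 2 \
#       --num-devices-per-machine 8 \
#       --num-buckets 10 --bucket-size 25000000 \
#       --num-epochs 5 --use-nccl \
#       --output results_machine0.pt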
import argparse
import sys

import torch

# run_one_machine (the actual benchmark loop) is assumed to be defined in the
# accompanying benchmark module of this package.


def main():
    parser = argparse.ArgumentParser(description="all-reduce benchmark")
    parser.add_argument(
        "--init-method",
        type=str,
        default="env://",
        help="Rendezvous method used to connect the machines "
        "(passed through to PyTorch; see its documentation)",
    )
    parser.add_argument(
        "--machine-idx",
        type=int,
        required=True,
        help="The rank of the machine on which this script was invoked (0-based)",
    )
    parser.add_argument(
        "--num-machines",
        type=int,
        required=True,
        help="The total number of machines on which this script is being invoked "
        "(each with its own rank)",
    )
    parser.add_argument(
        "--num-devices-per-machine",
        type=int,
        required=True,
        help="How many client processes this script should launch (each will use one GPU)",
    )
    parser.add_argument(
        "--num-buckets",
        type=int,
        required=True,
        help="How many buffers to do an allreduce over in each epoch",
    )
    parser.add_argument(
        "--bucket-size",
        type=int,
        required=True,
        help="How big each buffer should be (expressed in number of float32 elements)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        required=True,
        help="How many times to run the benchmark",
    )
    parser.add_argument(
        "--num-network-threads",
        type=int,
        help="The value of the NCCL_SOCKET_NTHREADS env var (see NCCL's doc)",
    )
    parser.add_argument(
        "--num-sockets-per-network-thread",
        type=int,
        help="The value of the NCCL_NSOCKS_PERTHREAD env var (see NCCL's doc)",
    )
    parser.add_argument(
        "--use-nccl",
        action="store_true",
        help="Use NCCL to perform the all-reduce",
    )
    # parser.add_argument(
    #     "--pid-file",
    #     type=str,
    # )
    parser.add_argument(
        "--parallelism",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--output",
        type=argparse.FileType("wb"),
        default=sys.stdout.buffer,
        help="Where to write the results (serialized with torch.save); defaults to stdout",
    )
    args = parser.parse_args()
    # Run the benchmark on this machine's devices and collect the results.
    res = run_one_machine(
        init_method=args.init_method,
        machine_idx=args.machine_idx,
        num_machines=args.num_machines,
        num_devices_per_machine=args.num_devices_per_machine,
        num_buckets=args.num_buckets,
        bucket_size=args.bucket_size,
        num_epochs=args.num_epochs,
        num_network_threads=args.num_network_threads,
        num_sockets_per_network_thread=args.num_sockets_per_network_thread,
        use_nccl=args.use_nccl,
        parallelism=args.parallelism,
        # pid_file=args.pid_file,
    )
    torch.save(res, args.output)


if __name__ == "__main__":
    main()