in train/comms/pt/comms_utils.py [0:0]
def readArgs(self, parser):
    """Register the options shared by all PARAM-Comm benchmarks on ``parser``.

    Args:
        parser: object exposing an argparse-style ``add_argument`` method;
            every option below is registered on it, in the order listed.

    ``self.supportedNwstacks`` and ``self.supportedDtype`` are only rendered
    into help text, and ``self.isCudaAvail()`` decides the defaults of
    ``--device`` and ``--backend``. Nothing is returned.
    """
    add_option = parser.add_argument

    # Master IP and port used to coordinate; the defaults are module-level
    # values defined outside this function.
    add_option(
        "--master-ip",
        default=default_master_ip,
        type=str,
        help="The master-IP to coordinate",
    )
    add_option(
        "--master-port",
        default=default_master_port,
        type=str,
        help="The master-port to coordinate",
    )

    # Network stack to profile; the supported list appears in the help only,
    # it is not enforced through ``choices``.
    nw_stack_help = f"network stack to be used, supports {self.supportedNwstacks!s}"
    add_option(
        "--nw-stack",
        default="pytorch-dist",
        type=str,
        help=nw_stack_help,
    )

    # Placeholder that is later overwritten based on args.data_type and
    # dtypeMap.
    # NOTE(review): torch.dtype does not look constructible from a string, so
    # passing --dtype on the command line presumably fails -- confirm.
    add_option("--dtype", default=torch.float32, type=torch.dtype)

    # Base data type, given by name; like --nw-stack, the supported list is
    # informational only.
    data_type_help = f"the base data type, supports {self.supportedDtype!s}"
    add_option(
        "--data-type",
        default="float32",
        type=str,
        help=data_type_help,
    )

    # Number of TPU cores.
    add_option(
        "--num-tpu-cores",
        default=1,
        type=int,
        help="number of TPU cores to be used",
    )

    # Logging level.
    add_option(
        "--log",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="ERROR",
        type=str,
        help="Logging level",
    )

    # Device on which data is placed for collective benchmarking; prefers
    # CUDA whenever it is reported as available.
    device_default = "cuda" if self.isCudaAvail() else "cpu"
    add_option(
        "--device",
        choices=["cpu", "cuda", "tpu"],
        default=device_default,
        type=str,
        help="data placement",
    )

    # Backend used for the network stack: nccl when CUDA is available,
    # otherwise mpi.
    backend_default = "nccl" if self.isCudaAvail() else "mpi"
    add_option(
        "--backend",
        choices=["nccl", "gloo", "mpi", "ucc", "xla"],
        default=backend_default,
        type=str,
        help="The backend to be used in PyTorch distributed process group",
    )

    # 1 = sync/blocking collectives, 0 = async/non-blocking.
    add_option(
        "--z",
        choices=[0, 1],
        default=1,
        type=int,
        help="use blocking mode for collectives",
    )

    # Comms quantization settings.
    add_option(
        "--bitwidth",
        choices=[2, 4, 8, 16, 32],
        default=32,
        type=int,
        help="Quantization bitwidth",
    )
    # Row dimension for quantization.
    add_option(
        "--quant-a2a-embedding-dim",
        choices=[32, 64, 128, 256],
        default=32,
        type=int,
        help="Embedding dimension used by quantization alltoall if enabled",
    )
    # Quantization threshold; the default 33554432 is 32 * 1024 * 1024 (32 MB).
    add_option(
        "--quant-threshold",
        default=33554432,
        type=int,
        help="threshold of message sizes to perform quantization if enabled",
    )

    # Data validation check toggle (0 = off, 1 = on).
    add_option(
        "--c",
        choices=[0, 1],
        default=0,
        type=int,
        help="enable data validation check",
    )