in train/comms/pt/comms.py [0:0]
def readArgs(self, parser):
    """Register the comms benchmark's command-line options and parse argv.

    The parent class's ``readArgs`` runs first so the common/basic options are
    registered before the benchmark-specific ones declared below.

    Args:
        parser: the argument parser that receives the options; it must provide
            ``add_argument`` and ``parse_known_args``.

    Returns:
        The result of ``parser.parse_known_args()``: a ``(namespace,
        leftover_args)`` pair, so unrecognized arguments are handed back to
        the caller instead of being rejected.
    """
    # Shared options come from the parent class and are added first.
    super().readArgs(parser)

    # Each entry is (flag, keyword options for add_argument). The list order is
    # the registration order, which is also the order shown by --help.
    option_specs = [
        # ---- iteration counts ----
        ("--w", dict(type=int, default=5, help="number of warmup iterations")),
        ("--n", dict(type=int, default=5, help="number of iterations")),
        # ---- experiment selection: plain comms, compute-only, DLRM, or mixed ----
        (
            "--mode",
            dict(
                type=str,
                default="comms",
                help="benchmark mode",
                choices=["comms", "compute", "dlrm", "comms-compute"],
            ),
        ),
        # ---- COMMS-mode message-size sweep: start, end, and growth factor ----
        # Sizes are accepted as strings here.
        # NOTE(review): presumably parsed into byte counts downstream — confirm.
        (
            "--b",
            dict(type=str, default="8", help="minimum size, in bytes, to start with"),
        ),
        (
            "--e",
            dict(type=str, default="64", help="maximum size, in bytes, to end at"),
        ),
        (
            "--f",
            dict(type=int, default=2, help="multiplication factor between sizes"),
        ),
        # ---- which collective to benchmark ----
        (
            "--collective",
            dict(
                type=str,
                default="all_reduce",
                help="Collective operation to be evaluated",
                choices=supportedCollectives,
            ),
        ),
        # ---- compute kernel for compute / comms-compute modes ----
        (
            "--kernel",
            dict(
                type=str,
                default="gemm",
                help="Compute kernel, used for comms-compute or compute mode",
                choices=["gemm", "emb_lookup"],
            ),
        ),
        # Ratio of compute kernels launched per collective.
        (
            "--num-compute",
            dict(
                type=int,
                default=100,
                help="one collective for every NUM_COMPUTE compute kernels",
            ),
        ),
        # GEMM kernel: square matrices of size mm_dim x mm_dim.
        (
            "--mm-dim",
            dict(
                type=int,
                default=100,
                help="dimension size for GEMM compute kernel",
            ),
        ),
        # ---- embedding-lookup kernel parameters ----
        (
            "--emb-dim",
            dict(
                type=int,
                default=128,
                help="dimension size for Embedding table compute kernel",
            ),
        ),
        (
            "--num-embs",
            dict(
                type=int,
                default=100000,
                help="Embedding table hash size for Embedding table compute kernel",
            ),
        ),
        (
            "--avg-len",
            dict(
                type=int,
                default=28,
                help="Average lookup operations per sample",
            ),
        ),
        (
            "--batch-size",
            dict(
                type=int,
                default=512,
                help="number of samples reading the table concurrently",
            ),
        ),
        # ---- rooted collectives (reduce, bcast; possibly gather/scatter later) ----
        # TODO: --root is not range-checked here; it should lie in [0, world_size - 1].
        (
            "--root",
            dict(type=int, default=0, help="root process for reduce benchmark"),
        ),
        # ---- rank groups for incast / multicast / pt2pt patterns ----
        # NOTE(review): the "R|" help prefix presumably asks a custom help formatter
        # to keep the embedded newlines verbatim — confirm against formatter_class.
        # Neither option sets a default, so omitting the flag yields None.
        (
            "--src-ranks",
            dict(
                type=str,
                nargs="?",
                help="R|src ranks for many-to-one incast pattern or pt2pt.\n"
                "List of ranks separated by comma or a range specified by start:end.\n"
                "Pt2pt one2one should set only one rank.\n"
                "The default value of incast includes all ranks, pt2pt includes rank 0.",
            ),
        ),
        (
            "--dst-ranks",
            dict(
                type=str,
                nargs="?",
                help="R|dst ranks for one-to-many multicast pattern or pt2pt.\n"
                "List of ranks separated by comma or a range specified by start:end.\n"
                "Pt2pt one2one should set only one rank\n"
                "The default value of multicast includes all ranks, pt2pt includes rank 1.",
            ),
        ),
        # ---- collective pair mode ----
        (
            "--pair",
            dict(
                action="store_true",
                default=False,
                help="Toggle to enable collective pair mode",
            ),
        ),
        # Second collective of the pair; --collective should also be set.
        (
            "--collective-pair",
            dict(
                type=str,
                default="all_reduce",
                help="Collective pair operation to be evaluated",
                choices=supportedCollectives,
            ),
        ),
        # NOTE(review): "pgs" presumably means process groups — confirm.
        (
            "--overlap-pair-pgs",
            dict(
                action="store_true",
                default=False,
                help="Toggle to enable overlapping collective pair with two pgs",
            ),
        ),
        # ---- point-to-point mode ----
        (
            "--pt2pt",
            dict(
                type=str,
                default=None,
                help="point to point pattern",
                choices=pt2ptPatterns,
            ),
        ),
        # Window size, only meaningful for the pt2pt throughput test.
        (
            "--window",
            dict(
                type=int,
                default=100,
                help="window size for pt2pt throughput test",
            ),
        ),
    ]

    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_known_args()