in train/comms/pt/comms.py [0:0]
def checkArgs(self, args):
    """Validate the parsed arguments and normalize them in place.

    After delegating to the parent class's ``checkArgs``, this method:

    * sets ``args.collective`` to ``"pt2pt"`` when ``args.pt2pt`` is given,
      and calls ``comms_utils.gracefulExit()`` if the pattern is not in
      ``pt2ptPatterns``;
    * replaces ``args.b`` and ``args.e`` with the results of
      ``comms_utils.parsesize``;
    * sets ``args.dtype`` from ``self.dtypeMap[args.data_type]``;
    * raises ``args.b`` to 1 when it is smaller than 1;
    * warns (without modifying anything) when ``args.e < args.b``;
    * resets ``args.c`` to 0 for all_reduce / reduce / reduce_scatter when
      ``args.c == 1`` and ``args.z == 0``;
    * calls ``comms_utils.checkQuantArgs`` when ``args.bitwidth < 32``.

    Args:
        args: the parsed argument object; several attributes are
            overwritten as described above.

    Raises:
        ValueError: if ``args.device`` is ``"cpu"`` while ``args.backend``
            is ``"nccl"``.
    """
    super().checkArgs(args)

    # Supplying a pt2pt pattern overrides whatever collective was selected.
    pt2pt_requested = args.pt2pt is not None
    if pt2pt_requested:
        args.collective = "pt2pt"
    if pt2pt_requested and args.pt2pt not in pt2ptPatterns:
        logger.error(
            f"Specified pt2pt pattern: {args.pt2pt} is not one of the supported pt2pt patterns: {str(pt2ptPatterns)}"
        )
        comms_utils.gracefulExit()

    # Run the begin size first, then the end size, through parsesize.
    for size_flag in ("b", "e"):
        setattr(args, size_flag, comms_utils.parsesize(getattr(args, size_flag)))
    args.dtype = self.dtypeMap[args.data_type]

    # NOTE(review): the message says "greater than 1 byte", yet a begin size
    # of exactly 1 passes this check and is also the value substituted here.
    if args.b < 1:
        logger.warning(
            f"Starting size (--b {args.b}) should be greater than 1 byte...fix and continue"
        )
        args.b = 1

    # Only reported; neither size is adjusted.
    if args.e < args.b:
        logger.warning(
            f"the begin-size (--b {args.b}) is larger than the end-size (--e {args.e})"
        )

    if args.device == "cpu" and args.backend == "nccl":
        raise ValueError(f"NCCL is not supported for device type {args.device}")

    # Validation is switched off for these collectives when args.z == 0
    # (the warning below describes that setting as non-blocking mode).
    unvalidated_collectives = ("all_reduce", "reduce", "reduce_scatter")
    if (
        args.c == 1
        and args.z == 0
        and args.collective in unvalidated_collectives
    ):
        logger.warning(
            f"Data validation is not supported for {args.collective} in non-blocking mode, disabled and continue"
        )
        args.c = 0

    # Sanity checks that apply only when a bitwidth below 32 is requested.
    quantized = args.bitwidth < 32
    if quantized and args.device != "cuda":
        logger.error(
            f"collective quantization may not be fully supported for {args.device}"
        )
    if quantized:
        comms_utils.checkQuantArgs(
            args.collective, args.dtype, args.b, args.quant_a2a_embedding_dim, args.z
        )