in train/comms/pt/comms.py [0:0]
def checkPt2PtRanks(self):
    """Fill in default point-to-point rank lists and sanity-check them.

    ``self.collectiveArgs`` is updated in place:

    * a falsy ``src_ranks`` (empty or ``None``) becomes ``[0]``;
    * a falsy ``dst_ranks`` (empty or ``None``) becomes ``[1]``.

    The lists are then validated according to ``collectiveArgs.pt2pt``:

    * ``"one2one"``: neither list may hold more than one rank.
    * ``"pairwise"``: both lists must have the same length and must not
      share any rank.
    * any other value: no validation is performed.

    On a violation, rank 0 reports the problem via ``logger.error`` and every
    rank calls ``comms_utils.gracefulExit()``. That helper presumably ends the
    run; if it ever returns, the remaining checks and the summary print below
    still execute. On the way out, rank 0 prints a one-line summary of the
    final configuration.

    Returns:
        None.
    """
    args = self.collectiveArgs
    on_rank_zero = self.global_rank == 0

    def _reject(message):
        # Only rank 0 logs (avoids N duplicate messages), but every rank
        # calls gracefulExit().
        if on_rank_zero:
            logger.error(message)
        comms_utils.gracefulExit()

    # Default to a single sender (rank 0) and a single receiver (rank 1).
    if not args.src_ranks:
        args.src_ranks = [0]
    if not args.dst_ranks:
        args.dst_ranks = [1]

    mode = args.pt2pt
    if mode == "one2one":
        if max(len(args.src_ranks), len(args.dst_ranks)) > 1:
            _reject(
                "One2one Pt2Pt requires only a single rank is specified in src_ranks and dst_ranks! "
            )
    elif mode == "pairwise":
        # Senders and receivers are matched up one-to-one, so the group
        # sizes have to agree.
        if len(args.src_ranks) != len(args.dst_ranks):
            _reject(
                "Pairwise Pt2Pt requires identical number of members in src_ranks and dst_ranks! "
            )
        # A rank may not appear on both the sending and the receiving side.
        if not set(args.src_ranks).isdisjoint(args.dst_ranks):
            _reject(
                "Pairwise Pt2Pt requires distinct members in src_ranks and dst_ranks! "
            )

    if on_rank_zero:
        print(
            f"\t collective={args.collective}\t{args.pt2pt}, src_ranks={args.src_ranks}, dst_ranks={args.dst_ranks}"
        )