in train/comms/pt/comms.py [0:0]
def checkCollectiveRanks(self):
    """Normalize the peer-rank list for root-based collectives.

    For ``incast`` the list is ``collectiveArgs.src_ranks``; for
    ``multicast`` it is ``collectiveArgs.dst_ranks``. In either case:

    * if the list is empty/falsy it is replaced by ``list(range(comm_size))``;
    * the root rank ``collectiveArgs.srcOrDst`` is then removed from it
      (first occurrence only, in place on the existing list object).

    Any other collective leaves both lists untouched. On global rank 0 the
    resulting collective name and both rank lists are printed.
    """
    args = self.collectiveArgs

    # Pick which rank list the root-based collective owns; None means the
    # collective needs no adjustment.
    if args.collective == "incast":
        peer_attr = "src_ranks"
    elif args.collective == "multicast":
        peer_attr = "dst_ranks"
    else:
        peer_attr = None

    if peer_attr is not None:
        peers = getattr(args, peer_attr)
        if not peers:
            # No explicit ranks given: default to every rank in the communicator.
            peers = list(range(self.comm_size))
            setattr(args, peer_attr, peers)
        root = args.srcOrDst
        # The root never sends to / receives from itself, so exclude it.
        if root in peers:
            peers.remove(root)

    if self.global_rank == 0:
        print(
            f"\t collective={args.collective}, src_ranks={args.src_ranks}, dst_ranks={args.dst_ranks}"
        )