in astra-sim-alibabacloud/astra-sim/system/MockNcclGroup.cc [1854:1920]
/**
 * Build (and memoize per group) the tree channels for the group that `rank`
 * belongs to under group type `type`.
 *
 * Every intra-node ring returned by gen_local_ring() is replicated onto each
 * node by adding `node * delta` to its ranks, where `delta` is the rank
 * distance between the first rank of node 0 and the first rank of node 1.
 * Each (ring, root returned by genInterDouBinTree()) pair yields one channel:
 * a rank -> ncclTree map, initialised to ncclTree(-1, rank, -1, {}) for every
 * rank of the group, then wired by ConnInterIntraTree().
 *
 * @param rank  Rank used (together with `type`) to look up the group in GroupIndex.
 * @param type  Group type used (together with `rank`) to look up the group.
 * @return Channel id -> (rank -> ncclTree). Returns an empty result and logs an
 *         error when the group is unknown or its layout is inconsistent.
 */
TreeChannels MockNcclGroup::gettreechannels(int rank, GroupType type){
  MockNcclLog* NcclLog = MockNcclLog::getInstance();

  auto idx_it = GroupIndex.find(std::make_pair(rank, type));
  if (idx_it == GroupIndex.end()) {
    NcclLog->writeLog(NcclLogLevel::ERROR,"There is no corresponding group info and group ring channel, resulting in an error in gettreechannels.");
    return {};
  }
  const int gp_idx = idx_it->second;

  // Channels depend only on the group, so a previously built result is reused.
  auto cached = Alltreechannels.find(gp_idx);
  if (cached != Alltreechannels.end()) {
    return cached->second;
  }

  GroupInfo gp_info = AllGroups[gp_idx];
  const int nNodes = gp_info.nNodes;
  // Guard the division below and the degenerate "fewer ranks than nodes" case.
  if (nNodes <= 0 || gp_info.nRanks < nNodes) {
    NcclLog->writeLog(NcclLogLevel::ERROR,"gettreechannels: invalid group layout (nNodes <= 0 or nRanks < nNodes).");
    return {};
  }
  const int nlocalRanks = gp_info.nRanks / nNodes;

  // delta is measured between the first two nodes only, so the same offset is
  // assumed to hold for every node of the group.
  if (nNodes > 1 && static_cast<int>(gp_info.Ranks.size()) <= nlocalRanks) {
    NcclLog->writeLog(NcclLogLevel::ERROR,"gettreechannels: group rank list is too short for its node layout.");
    return {};
  }
  const int delta = nNodes > 1 ? gp_info.Ranks[nlocalRanks] - gp_info.Ranks[0] : 0;

  std::map<int, std::vector<int>> localrings = gen_local_ring(rank, type);

  // ring id -> (node index -> ranks of that ring on that node), obtained by
  // shifting the ring's first nlocalRanks entries by node * delta.
  std::map<int, std::map<int, std::vector<int>>> allnode2ranks;
  for (const auto& ring : localrings) {
    const std::vector<int>& local_ring = ring.second;
    if (static_cast<int>(local_ring.size()) < nlocalRanks) {
      NcclLog->writeLog(NcclLogLevel::ERROR,"gettreechannels: local ring is shorter than the number of ranks per node.");
      return {};
    }
    std::map<int, std::vector<int>>& node2ranks = allnode2ranks[ring.first];
    for (int node = 0; node < nNodes; node++) {
      std::vector<int>& node_ranks = node2ranks[node];
      node_ranks.reserve(nlocalRanks);
      for (int j = 0; j < nlocalRanks; j++) {
        node_ranks.push_back(local_ring[j] + node * delta);
      }
    }
  }

  // NOTE(review): roots are raw pointers returned by genInterDouBinTree();
  // their ownership is not visible here — confirm whether they must be freed.
  std::vector<DoubleBinaryTreeNode*> roots = genInterDouBinTree(gp_info);

  TreeChannels treechannels;
  int channel_id = 0;
  for (auto& ring : allnode2ranks) {
    // Shared by all roots of this ring, like the per-ring map the channels are wired from.
    std::map<int, std::vector<int>>& node2ranks = ring.second;
    for (DoubleBinaryTreeNode* root : roots) {
      std::map<int, ncclTree> treechannel;
      for (int r : gp_info.Ranks) {
        treechannel[r] = ncclTree(-1, r, -1, {});
      }
      ConnInterIntraTree(root, node2ranks, treechannel);
      treechannels[channel_id] = std::move(treechannel);
      channel_id++;
    }
  }

  // Memoize once, after all channels are built; nothing is cached when no
  // local ring was found, so a later call rebuilds in that case.
  if (!allnode2ranks.empty()) {
    Alltreechannels[gp_idx] = treechannels;
  }
  return treechannels;
}