in astra-sim-alibabacloud/astra-sim/system/MockNcclGroup.cc [1724:1823]
/**
 * Build (or fetch from the cache) the NVLS-tree channels of the group that
 * `rank` belongs to for group type `type`.
 *
 * One channel is produced per (inter-node tree root, local rank index) pair.
 * For every channel, an intra-node topology is generated for each node via
 * gen_nvls_tree_intra_channels(), and the per-node roots are then linked
 * together by gen_nvls_tree_inter_channels() using the tree returned by
 * genInterDouBinTree().
 *
 * The result is stored in AllNVLStreechannels under the group index, so later
 * calls for the same group return the stored value.
 *
 * @param rank  Rank used to look up the group in GroupIndex.
 * @param type  Group type used to look up the group in GroupIndex.
 * @return The channels keyed by channel id, or an empty value when the group
 *         is unknown, has no nodes, or local ring 0 is missing / too short.
 */
NVLStreechannels MockNcclGroup::get_nvls_tree_channels(int rank,GroupType type){
  MockNcclLog* NcclLog = MockNcclLog::getInstance();

  // Resolve the group this (rank, type) pair belongs to.
  const auto group_it = GroupIndex.find(std::make_pair(rank,type));
  if(group_it == GroupIndex.end()){
    NcclLog->writeLog(NcclLogLevel::ERROR,"There is no corresponding group info , resulting in an error in get_nvls_tree_channels.");
    return {};
  }
  const int gp_idx = group_it->second;

  // Serve from the cache when this group has already been processed.
  const auto cached_it = AllNVLStreechannels.find(gp_idx);
  if(cached_it != AllNVLStreechannels.end()){
    return cached_it->second;
  }

  GroupInfo gp_info = AllGroups[gp_idx];
  const int nNodes = gp_info.nNodes;
  if(nNodes <= 0){
    // Guards the division below; without it nNodes == 0 divides by zero.
    NcclLog->writeLog(NcclLogLevel::ERROR,"The group has no nodes , resulting in an error in get_nvls_tree_channels.");
    return {};
  }
  const int nlocalRanks = gp_info.nRanks/nNodes;

  std::vector<DoubleBinaryTreeNode*> roots = genInterDouBinTree(gp_info);
  std::map<int,std::vector<int>> localrings = gen_local_ring(rank,type);

  // Rank offset between the same local slot on two consecutive nodes.
  const int delta = nNodes > 1 ? gp_info.Ranks[nlocalRanks]-gp_info.Ranks[0] : 0;

  // Only local ring 0 is used to derive the per-node rank lists. The entry for
  // (node, j) is ring0[j] + node * delta, which is exactly what expanding the
  // ring across all nodes and slicing it back per node would yield.
  const auto ring0_it = localrings.find(0);
  if(ring0_it == localrings.end() ||
     static_cast<int>(ring0_it->second.size()) < nlocalRanks){
    NcclLog->writeLog(NcclLogLevel::ERROR,"Local ring 0 is missing or too short , resulting in an error in get_nvls_tree_channels.");
    return {};
  }
  const std::vector<int>& ring0 = ring0_it->second;
  std::vector<std::vector<int>> node2ranks(nNodes);
  for(int node = 0; node < nNodes; ++node){
    node2ranks[node].reserve(nlocalRanks);
    for(int j = 0; j < nlocalRanks; ++j){
      node2ranks[node].push_back(ring0[j] + node * delta);
    }
  }

  std::map<int,std::map<int,std::vector<ncclChannelNode*>>> nvlstreechannels;
  int channel_id = 0;
  for(DoubleBinaryTreeNode* inter_root : roots){
    for(int index = 0; index < nlocalRanks; ++index){
      std::map<int,std::vector<ncclChannelNode*>> nvlstreechannel;
      std::map<int,ncclChannelNode*> nodencclchannlenodes;

      for(int node = 0; node < nNodes; ++node){
        const std::vector<int>& noderanks = node2ranks[node];

        // Intra-node topology: the selected local rank, the node's NVSwitch,
        // then every rank of the node.
        std::vector<int> intra_topo;
        intra_topo.reserve(noderanks.size() + 2);
        intra_topo.push_back(noderanks[index]);
        intra_topo.push_back(gp_info.NVSwitchs[node]);
        intra_topo.insert(intra_topo.end(), noderanks.begin(), noderanks.end());

        NcclLog->writeLog(NcclLogLevel::DEBUG," node %d intra_topo",node);
        for(int num : intra_topo){
          NcclLog->writeLog(NcclLogLevel::DEBUG," %d",num);
        }

        ncclChannelNode* intra_root =
            gen_nvls_tree_intra_channels(intra_topo, nvlstreechannel);
        nodencclchannlenodes[node] = intra_root;
      }

      // Debug dump of the intra-node trees; emitted only when called for rank 0
      // and before the inter-node links are added.
      if(rank == 0){
        for(const auto& rank_nodes : nvlstreechannel){
          NcclLog->writeLog(NcclLogLevel::DEBUG," rank %d nvls tree nodes ",rank_nodes.first);
          int node_pos = 0;
          for(ncclChannelNode* nvlstreenode : rank_nodes.second){
            NcclLog->writeLog(NcclLogLevel::DEBUG," node %d rank %d",node_pos,nvlstreenode->rank);
            if(nvlstreenode->up != nullptr)
              NcclLog->writeLog(NcclLogLevel::DEBUG," up %d",nvlstreenode->up->rank);
            NcclLog->writeLog(NcclLogLevel::DEBUG," down ");
            for(auto down : nvlstreenode->down){
              NcclLog->writeLog(NcclLogLevel::DEBUG," %d ",down->rank);
            }
            // Advance the position so each entry is logged with its own index.
            ++node_pos;
          }
        }
      }

      gen_nvls_tree_inter_channels(
          inter_root, nodencclchannlenodes, nvlstreechannel);
      nvlstreechannels[channel_id] = nvlstreechannel;
      ++channel_id;
    }
  }

  AllNVLStreechannels[gp_idx] = nvlstreechannels;
  return nvlstreechannels;
}