TreeChannels MockNcclGroup::gettreechannels()

in astra-sim-alibabacloud/astra-sim/system/MockNcclGroup.cc [1854:1920]


  /**
   * Returns the tree channels of the group that (rank, type) maps to in
   * GroupIndex, building and caching them in Alltreechannels on first use.
   *
   * For every local ring returned by gen_local_ring(rank, type) and every
   * root returned by genInterDouBinTree(), one channel is produced. Channel
   * ids are assigned consecutively starting at 0, iterating rings in key
   * order and, within a ring, roots in the order they were returned.
   *
   * @param rank  rank used (together with type) to look up the group.
   * @param type  group type used (together with rank) to look up the group.
   * @return channel id -> (rank -> ncclTree). An empty result is returned
   *         (after logging an error) when the group is unknown or the group
   *         info is inconsistent with the local rings.
   */
  TreeChannels MockNcclGroup::gettreechannels(int rank, GroupType type){
    MockNcclLog* NcclLog = MockNcclLog::getInstance();

    // Single lookup instead of count() followed by operator[].
    auto group_it = GroupIndex.find(std::make_pair(rank,type));
    if(group_it == GroupIndex.end()){
      NcclLog->writeLog(NcclLogLevel::ERROR,"There is no corresponding group info and group ring channel, resulting in an error in gettreechannels.");
      return {};
    }
    const int gp_idx = group_it->second;

    // Serve from the cache when this group was already processed.
    auto cached_it = Alltreechannels.find(gp_idx);
    if(cached_it != Alltreechannels.end()){
      return cached_it->second;
    }

    GroupInfo gp_info = AllGroups[gp_idx];

    const int nNodes = gp_info.nNodes;
    if(nNodes <= 0){
      // Guards the division below; a group without nodes has no channels.
      NcclLog->writeLog(NcclLogLevel::ERROR,"Invalid node count in group info, resulting in an error in gettreechannels.");
      return {};
    }
    const int nlocalRanks = gp_info.nRanks/nNodes;

    const std::map<int,std::vector<int>> localrings = gen_local_ring(rank,type);

    // Rank offset between two consecutive nodes, derived from the first rank
    // of node 1 and the first rank of node 0. NOTE(review): assumes Ranks is
    // laid out node by node with a constant stride - confirm against the
    // code that fills GroupInfo.
    int delta = 0;
    if(nNodes > 1){
      if(nlocalRanks < 0 ||
         static_cast<std::size_t>(nlocalRanks) >= gp_info.Ranks.size()){
        NcclLog->writeLog(NcclLogLevel::ERROR,"Group ranks are inconsistent with the node count, resulting in an error in gettreechannels.");
        return {};
      }
      delta = gp_info.Ranks[nlocalRanks]-gp_info.Ranks[0];
    }

    // ring id -> node index -> ranks of that node, in local ring order.
    // Node i holds the local ring ranks shifted by i * delta.
    std::map<int, std::map<int, std::vector<int>>> allnode2ranks;
    if(nlocalRanks > 0){
      // With no local ranks nothing is inserted, so no channel is built.
      for(const auto& local_ring : localrings){
        const std::vector<int>& ring_ranks = local_ring.second;
        if(ring_ranks.size() < static_cast<std::size_t>(nlocalRanks)){
          NcclLog->writeLog(NcclLogLevel::ERROR,"Local ring is shorter than the number of local ranks, resulting in an error in gettreechannels.");
          return {};
        }
        std::map<int, std::vector<int>>& node2ranks = allnode2ranks[local_ring.first];
        for(int node = 0; node < nNodes; node++){
          std::vector<int>& node_ranks = node2ranks[node];
          node_ranks.reserve(static_cast<std::size_t>(nlocalRanks));
          for(int j = 0; j < nlocalRanks; j++){
            node_ranks.push_back(ring_ranks[j] + node * delta);
          }
        }
      }
    }

    std::vector<DoubleBinaryTreeNode*> roots = genInterDouBinTree(gp_info);

    TreeChannels treechannels;
    int channel_id = 0;
    for(auto& ring_entry : allnode2ranks){
      std::map<int, std::vector<int>>& node2ranks = ring_entry.second;
      for(DoubleBinaryTreeNode* root : roots){
        // Start every rank of the group with a tree entry built from
        // (-1, rank, -1, {}); ConnInterIntraTree then fills in the links.
        std::map<int, ncclTree> treechannel;
        for(int member : gp_info.Ranks){
          treechannel[member] = ncclTree(-1, member, -1, {});
        }
        ConnInterIntraTree(root, node2ranks, treechannel);
        treechannels[channel_id] = std::move(treechannel);
        channel_id++;
      }
    }

    // Store the finished result once instead of re-copying the growing map
    // after every ring. As before, nothing is cached when no ring produced
    // a node mapping.
    if(!allnode2ranks.empty()){
      Alltreechannels[gp_idx] = treechannels;
    }
    return treechannels;
  }