NVLStreechannels MockNcclGroup::get_nvls_tree_channels()

in astra-sim-alibabacloud/astra-sim/system/MockNcclGroup.cc [1724:1823]


  /**
   * Returns the NVLS tree channels of the group that (rank, type) maps to,
   * building and caching them in AllNVLStreechannels on first use.
   *
   * For every root returned by genInterDouBinTree() and every local rank
   * position `index`, one channel is produced: an intra-node topology
   * {noderanks[index], NVSwitch of the node, all ranks of the node} is handed
   * to gen_nvls_tree_intra_channels() for each node, and the resulting
   * per-node nodes are then passed to gen_nvls_tree_inter_channels() together
   * with the inter-node tree root.
   *
   * @param rank  rank used to look up the group in GroupIndex; rank 0 also
   *              dumps the intra-node trees at DEBUG level.
   * @param type  group type used together with `rank` as the GroupIndex key.
   * @return channel id -> (key -> channel nodes) map; an empty value when the
   *         group is unknown, the node count is not positive, or local ring 0
   *         is missing / too short.
   */
  NVLStreechannels MockNcclGroup::get_nvls_tree_channels(int rank,GroupType type){
    MockNcclLog* NcclLog = MockNcclLog::getInstance();

    if(GroupIndex.count(std::make_pair(rank,type))==0){
      NcclLog->writeLog(NcclLogLevel::ERROR,"There is no corresponding group info , resulting in an error in get_nvls_tree_channels.");
      return {};
    }
    const int gp_idx = GroupIndex[std::make_pair(rank,type)];
    GroupInfo gp_info = AllGroups[gp_idx];
    if(AllNVLStreechannels.count(gp_idx)){
      return AllNVLStreechannels[gp_idx];
    }

    // nNodes is used as a divisor below; reject non-positive values instead
    // of dividing by zero.
    const int nNodes = gp_info.nNodes;
    if(nNodes <= 0){
      NcclLog->writeLog(NcclLogLevel::ERROR,"Invalid node count %d in group info , resulting in an error in get_nvls_tree_channels.",nNodes);
      return {};
    }

    std::vector<DoubleBinaryTreeNode*> roots = genInterDouBinTree(gp_info);

    const int nlocalRanks = gp_info.nRanks/nNodes;
    std::map<int,std::vector<int>> localrings = gen_local_ring(rank,type);
    // Rank offset between consecutive nodes.
    // NOTE(review): assumes gp_info.Ranks holds more than nlocalRanks entries
    // whenever nNodes > 1 - confirm against the GroupInfo construction.
    const int delta = nNodes > 1 ? gp_info.Ranks[nlocalRanks]-gp_info.Ranks[0] : 0;

    // Only local ring 0 feeds the channel construction, so it is the only one
    // that has to be expanded; make sure it can be indexed safely first.
    auto ring0_it = localrings.find(0);
    if(ring0_it == localrings.end() ||
       static_cast<int>(ring0_it->second.size()) < nlocalRanks){
      NcclLog->writeLog(NcclLogLevel::ERROR,"Local ring 0 is missing or has fewer than %d ranks , resulting in an error in get_nvls_tree_channels.",nlocalRanks);
      return {};
    }
    const std::vector<int>& local_ring = ring0_it->second;

    // node2ranks[node][j] = j-th rank of local ring 0 shifted onto `node`.
    std::vector<std::vector<int>> node2ranks(nNodes);
    for (int node = 0; node < nNodes; node++) {
      for (int j = 0; j < nlocalRanks; j++) {
        node2ranks[node].push_back(local_ring[j] + node * delta);
      }
    }

    std::map<int,std::map<int,std::vector<ncclChannelNode*>>> nvlstreechannels;
    int channel_id = 0;
    for (DoubleBinaryTreeNode* inter_root : roots) {
      for (int index = 0; index < nlocalRanks; index++) {
        std::map<int, std::vector<ncclChannelNode*>> nvlstreechannel;
        std::map<int,ncclChannelNode*> nodencclchannlenodes;
        for (int i = 0; i < nNodes; i++) {
          const std::vector<int>& noderanks = node2ranks[i];
          // Layout handed to the intra-node generator:
          // [rank at position `index`, NVSwitch of node i, every rank of node i].
          // NOTE(review): assumes gp_info.NVSwitchs has an entry for every
          // node index - confirm against the GroupInfo construction.
          std::vector<int> intra_topo;
          intra_topo.push_back(noderanks[index]);
          intra_topo.push_back(gp_info.NVSwitchs[i]);
          intra_topo.insert(
              intra_topo.end(), noderanks.begin(), noderanks.end());
          NcclLog->writeLog(NcclLogLevel::DEBUG," node  %d intra_topo",i);
          for(auto num:intra_topo){
            NcclLog->writeLog(NcclLogLevel::DEBUG," %d",num);
          }
          ncclChannelNode* intra_root =
              gen_nvls_tree_intra_channels(intra_topo, nvlstreechannel);
          nodencclchannlenodes[i] = intra_root;
        }

        // Dump the intra-node trees once (rank 0 only) before the inter-node
        // links are generated.
        if (rank == 0) {
          for (const auto& entry : nvlstreechannel) {
            NcclLog->writeLog(NcclLogLevel::DEBUG," rank  %d nvls tree nodes ",entry.first);
            int node_idx = 0;
            for (ncclChannelNode* nvlstreenode : entry.second) {
              NcclLog->writeLog(NcclLogLevel::DEBUG," node  %d rank  %d",node_idx,nvlstreenode->rank);
              if(nvlstreenode->up!=nullptr)
                NcclLog->writeLog(NcclLogLevel::DEBUG," up  %d",nvlstreenode->up->rank);
              NcclLog->writeLog(NcclLogLevel::DEBUG," down ");
              for (auto down : nvlstreenode->down) {
                NcclLog->writeLog(NcclLogLevel::DEBUG," %d ",down->rank);
              }
              // Advance the printed position; it previously stayed at 0 for
              // every entry.
              node_idx++;
            }
          }
        }

        gen_nvls_tree_inter_channels(
            inter_root, nodencclchannlenodes, nvlstreechannel);

        nvlstreechannels[channel_id] = nvlstreechannel;
        channel_id++;
      }
    }
    AllNVLStreechannels[gp_idx] = nvlstreechannels;
    return nvlstreechannels;
  }