void NcclTreeFlowModel::run()

in astra-sim-alibabacloud/astra-sim/system/collective/NcclTreeFlowModel.cc [134:295]


// Event entry point of the tree flow model. `data` is first viewed as a
// BasicEventHandlerData and then narrowed according to `event`:
//   - General:           calls ready() (phy_ready() under PHY_MTP) for the
//                        channel/flow carried by the event.
//   - PacketReceived:    accounts for the received packet and decrements the
//                        indegree of every flow in flowTag.tree_flow_list;
//                        a flow whose indegree reaches 0 is started.
//   - StreamInit:        starts the flows sourced at this node (src == id)
//                        that have no parent flow, once per parallel_reduce
//                        iteration and channel.
//   - PacketSentFinshed: calls reduce() for the sent flow and, in the
//                        non-PHY_MTP build, starts the next queued flow of
//                        the same (channel, sender, receiver) triple.
// Any other event type falls through without doing anything.
// PHY_MTP / PHY_RDMA select the code path at compile time.
void NcclTreeFlowModel::run(EventType event, CallData* data) {
  BasicEventHandlerData* ehd = (BasicEventHandlerData*)data;
  MockNcclLog* NcclLog = MockNcclLog::getInstance();
  if (event == EventType::General) {
    int channel_id = ehd->channel_id;
    int flow_id = ehd->flow_id;
    #ifndef PHY_MTP
    ready(channel_id, flow_id);
    #else
    phy_ready(channel_id, flow_id);
    #endif
  } else if (event == EventType::PacketReceived) {
    RecvPacketEventHadndlerData* rcehd = (RecvPacketEventHadndlerData*)ehd;
    // Work on a private copy of the tag; the handlers called below are opaque
    // from here, so the event data is not referenced after this point.
    AstraSim::ncclFlowTag flowTag = rcehd->flowTag;
    int channel_id = flowTag.channel_id;
    const std::vector<int>& next_flow_list = flowTag.tree_flow_list;
    #ifdef PHY_MTP
    recv_packets--;
    if(!phy_iteratable(channel_id)){
      return;
    }
    #else
    // Debug-build sanity check: every follow-up flow is either the
    // placeholder -1 or a flow registered for this channel.
    bool all_next_flows_known = true;
    for (int next_flow_id : next_flow_list) {
      if (next_flow_id != -1 &&
          _flow_models.count(std::make_pair(channel_id, next_flow_id)) == 0) {
        all_next_flows_known = false;
        break;
      }
    }
    assert(all_next_flows_known);
    // NOTE(review): FlowCriticalSection looks like a lock guard that
    // ExitSection() releases early - confirm against its definition.
    NcclTreeFlowModel::FlowCriticalSection cs;
    free_packets[std::make_pair(channel_id, flowTag.sender_node)]--;
    bool all_streams_done = true;
    for (int i = 0; i < m_channels; i++) {
      if (_stream_count[i] != 0) {
        all_streams_done = false;
        break;
      }
    }
    cs.ExitSection();
    if (all_streams_done) {
      // No channel has a non-zero stream count: finish up and stop here.
      ready(channel_id, -1);
      iteratable(channel_id);
      return;
    }
    #endif
    // size() is a size_t; cast it so the argument matches the %d conversion.
    NcclLog->writeLog(NcclLogLevel::DEBUG,"PacketReceived sender_node:  %d recevier  %d current_flow id:  %d channel_id:  %d tag_id  %d free_packets  %d next_flow_list.size %d",flowTag.sender_node,flowTag.receiver_node,flowTag.current_flow_id,flowTag.channel_id,flowTag.tag_id,free_packets[std::make_pair(channel_id,flowTag.sender_node)],static_cast<int>(next_flow_list.size()));
    #ifdef PHY_MTP
    for (int next_flow_id : next_flow_list){
      if (--indegree_mapping[next_flow_id] == 0) {
        phy_ready(channel_id, next_flow_id);
      }
    }
    #else
    NcclLog->writeLog(NcclLogLevel::DEBUG,"next_flow_list.size %d",static_cast<int>(next_flow_list.size()));
    bool all_indegrees_known = true;
    for (int next_flow_id : next_flow_list) {
      NcclTreeFlowModel::FlowCriticalSection flow_cs;
      if (indegree_mapping.count(next_flow_id) == 0) {
        // Processing stops at the first id without an indegree entry. The
        // placeholder -1 is tolerated by the validation above, so only a
        // real flow id counts as missing.
        all_indegrees_known = (next_flow_id == -1);
        flow_cs.ExitSection();
        break;
      }
      const bool parents_done = (--indegree_mapping[next_flow_id] == 0);
      flow_cs.ExitSection();
      if (parents_done) {
        // insert_packets() runs outside the critical section.
        insert_packets(channel_id, next_flow_id);
      }
    }
    assert(all_indegrees_known);
    #endif
  } else if (event == EventType::StreamInit) {
    #ifdef PHY_MTP
    MPI_Barrier(MPI_COMM_WORLD);
    for (const auto& single_flow : _flow_models) {
      if((single_flow.second.src==id||single_flow.second.dest==id)){
        #ifdef PHY_RDMA
        flow_rdma.ibv_create_peer_qp(id,single_flow.second.channel_id,single_flow.second.src,single_flow.second.dest,single_flow.second.chunk_count + 1 ,single_flow.second.chunk_id,single_flow.second.flow_size);
        #endif
      }
    }
    MPI_Barrier(MPI_COMM_WORLD);
    auto now = std::chrono::system_clock::now();
    auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count();
    // The tick count's integer type is platform dependent; cast for %lld.
    NcclLog->writeLog(NcclLogLevel::DEBUG,"streamInit time %lld",static_cast<long long>(now_us));
    start_time = std::chrono::high_resolution_clock::now();
    #endif
    for (int i = 0; i < parallel_reduce; i++) {
      #ifndef PHY_MTP
      init_recv_ready();
      #endif
      for(int j = 0; j < m_channels; j ++) {
        for (const auto& flow_model : _flow_models) {
          const auto& flow = flow_model.second;
          // Only root flows (no parent) sent by this node on channel j.
          if (flow.src != id) continue;
          if (!flow.parent_flow_id.empty() || flow.channel_id != j) continue;
          #ifdef PHY_MTP
          if (flow.chunk_id == 0) {
            phy_ready(j, flow.flow_id);
          }
          #else
          const auto qp_key = std::make_pair(
              flow.channel_id, std::make_pair(flow.src, flow.dest));
          if (flow.chunk_id == 0) {
            // Chunk 0 is sent right away; peer_qps is set to 0 for its key
            // (the send-finished handler sets it back to 1).
            pQps->peer_qps[qp_key] = 0;
            insert_packets(j, flow.flow_id);
          } else {
            // Other chunks are queued and started one at a time by the
            // PacketSentFinshed branch.
            pQps->peer_wating_tasks[qp_key].push(flow.flow_id);
          }
          #endif
        }
      }
      #ifdef PHY_MTP
      waiting_to_exit();
      NcclLog->writeLog(NcclLogLevel::DEBUG, "NcclTreeFlowModel::waiting_to_exit end ");
      #endif
    }
  } else if(event == EventType::PacketSentFinshed){
    SendPacketEventHandlerData* rcehd = (SendPacketEventHandlerData*)ehd;
    AstraSim::ncclFlowTag flowTag = rcehd->flowTag;
    int sent_flow_id = flowTag.current_flow_id;
    int channel_id = flowTag.channel_id;
    NcclLog->writeLog(NcclLogLevel::DEBUG,"PacketSentFinshed src %d dst %d channel_id %d flow_id %d",flowTag.sender_node,flowTag.receiver_node,flowTag.channel_id,flowTag.current_flow_id);
    reduce(channel_id,sent_flow_id);
    #ifndef PHY_MTP
    const auto qp_key = std::make_pair(
        flowTag.channel_id,
        std::make_pair(flowTag.sender_node, flowTag.receiver_node));
    NcclTreeFlowModel::FlowCriticalSection cs;
    pQps->peer_qps[qp_key] = 1;
    cs.ExitSection();
    // Start the next queued flow for this key, if any, and set peer_qps
    // back to 0 for it.
    auto& waiting_flows = pQps->peer_wating_tasks[qp_key];
    if (!waiting_flows.empty()) {
      int cur_flow_id = waiting_flows.front();
      waiting_flows.pop();
      pQps->peer_qps[qp_key] = 0;
      insert_packets(channel_id,cur_flow_id);
    }
    iteratable(channel_id);
    #else
    phy_iteratable(channel_id);
    #endif
  }
}