bool NcclTreeFlowModel::ready()

in astra-sim-alibabacloud/astra-sim/system/collective/NcclTreeFlowModel.cc [513:606]


bool NcclTreeFlowModel::ready(int channel_id, int flow_id) {
  MockNcclLog* NcclLog = MockNcclLog::getInstance();
  MyPacket packet;
  {
    if (stream->state == StreamState::Created ||
        stream->state == StreamState::Ready) {
      stream->changeState(StreamState::Executing);
    }
    if (!enabled || packets[std::make_pair(channel_id, flow_id)].size() == 0 || _stream_count[channel_id] == 0) {
      NcclLog->writeLog(NcclLogLevel::DEBUG,"NcclTreeFlowModel not ready!");
      return false;
    }
    packet = packets[std::make_pair(channel_id, flow_id)].front();
  }
  std::vector<int>recv_prevs;
  recv_prevs = _flow_models[std::make_pair(channel_id, flow_id)].prev;
  for (int recv_prev : recv_prevs) {
    sim_request rcv_req;
    rcv_req.vnet = this->stream->current_queue_id;
    rcv_req.layerNum = layer_num;
    rcv_req.reqCount = packet.msg_size;
    rcv_req.tag = channel_id;
    RecvPacketEventHadndlerData* ehd = new RecvPacketEventHadndlerData(
        stream,
        stream->owner->id,
        EventType::PacketReceived,
        packet.preferred_vnet,
        packet.stream_num);
    ehd->flowTag.child_flow_id = -1;
    ehd->flowTag.current_flow_id = -1;
    auto flow_model = this->_flow_models[std::make_pair(channel_id,flow_id)];
    if(flow_model.parent_flow_id.size()==0 || flow_model.conn_type == "RING"){
      ehd->flowTag.tag_id = layer_num*flow_model.chunk_count*m_channels + flow_model.chunk_count*flow_model.channel_id+flow_model.chunk_id;
    }else{
      ehd->flowTag.tag_id = layer_num*flow_model.chunk_count*m_channels + flow_model.chunk_count*flow_model.channel_id+flow_model.chunk_id+1;
    }
    ehd->flowTag.channel_id = packet.channel_id;
    if (free_packets[std::make_pair(channel_id, recv_prev)] > 0) {
      stream->owner->front_end_sim_recv(
          0,
          Sys::dummy_data,
          rcv_req.reqCount,
          UINT8,
          recv_prev,
          rcv_req.tag,
          &rcv_req,
          &Sys::handleEvent,
          ehd);
    }
  }
  sim_request snd_req;
  snd_req.srcRank = id;
  snd_req.dstRank = packet.preferred_dest;
  snd_req.tag = channel_id;
  snd_req.reqType = UINT8;
  snd_req.vnet = this->stream->current_queue_id;
  snd_req.layerNum = layer_num;
  snd_req.reqCount = packet.msg_size;
  MockNccl::SingleFlow flow_model =
      this->_flow_models[std::make_pair(channel_id, flow_id)];
  snd_req.flowTag.tag_id = layer_num * flow_model.chunk_count * m_channels +
      flow_model.channel_id * flow_model.chunk_count + flow_model.chunk_id;
  snd_req.flowTag.channel_id = channel_id;
  snd_req.flowTag.flow_size = flow_model.flow_size;
  snd_req.flowTag.current_flow_id = flow_id;
  snd_req.flowTag.chunk_id = flow_model.chunk_id;
  snd_req.flowTag.child_flow_id = -1;
  snd_req.flowTag.tree_flow_list =
      this->_flow_models[std::make_pair(channel_id, flow_id)].child_flow_id;
  snd_req.flowTag.sender_node = id;
  snd_req.flowTag.receiver_node = packet.preferred_dest;
  snd_req.flowTag.pQps = this->pQps;
  if (this->comType == ComType::All_Reduce_NVLS)
    snd_req.flowTag.nvls_on = true;
  else
    snd_req.flowTag.nvls_on = false;
  SendPacketEventHandlerData* send_ehd = new SendPacketEventHandlerData(
      stream,
      id,
      packet.preferred_dest,
      channel_id,
      EventType::PacketSentFinshed);
  stream->owner->front_end_sim_send(
      0,
      Sys::dummy_data,
      snd_req.reqCount,
      UINT8,
      packet.preferred_dest,
      snd_req.flowTag.tag_id,
      &snd_req,
      &Sys::handleEvent,
      send_ehd);
  return true;
}