in astra-sim-alibabacloud/astra-sim/network_frontend/ns3/AstraSimNetwork.cc [134:208]
virtual int sim_recv(void *buffer, uint64_t count, int type, int src, int tag,
AstraSim::sim_request *request,
void (*msg_handler)(void *fun_arg), void *fun_arg) {
#ifdef NS3_MTP
MtpInterface::explicitCriticalSection cs;
#endif
MockNcclLog* NcclLog = MockNcclLog::getInstance();
AstraSim::ncclFlowTag flowTag = request->flowTag;
src += npu_offset;
task1 t;
t.src = src;
t.dest = rank;
t.count = count;
t.type = 1;
t.fun_arg = fun_arg;
t.msg_handler = msg_handler;
AstraSim::RecvPacketEventHadndlerData* ehd = (AstraSim::RecvPacketEventHadndlerData*) t.fun_arg;
AstraSim::EventType event = ehd->event;
tag = ehd->flowTag.tag_id;
NcclLog->writeLog(NcclLogLevel::DEBUG,"接收事件注册 src %d sim_recv on rank %d tag_id %d channdl id %d",src,rank,tag,ehd->flowTag.channel_id);
if (recvHash.find(make_pair(tag, make_pair(t.src, t.dest))) !=
recvHash.end()) {
uint64_t count = recvHash[make_pair(tag, make_pair(t.src, t.dest))];
if (count == t.count) {
recvHash.erase(make_pair(tag, make_pair(t.src, t.dest)));
assert(ehd->flowTag.child_flow_id == -1 && ehd->flowTag.current_flow_id == -1);
if(receiver_pending_queue.count(std::make_pair(std::make_pair(rank, src),tag))!= 0) {
AstraSim::ncclFlowTag pending_tag = receiver_pending_queue[std::make_pair(std::make_pair(rank, src),tag)];
receiver_pending_queue.erase(std::make_pair(std::make_pair(rank,src),tag));
ehd->flowTag = pending_tag;
}
#ifdef NS3_MTP
cs.ExitSection();
#endif
t.msg_handler(t.fun_arg);
goto sim_recv_end_section;
} else if (count > t.count) {
recvHash[make_pair(tag, make_pair(t.src, t.dest))] = count - t.count;
assert(ehd->flowTag.child_flow_id == -1 && ehd->flowTag.current_flow_id == -1);
if(receiver_pending_queue.count(std::make_pair(std::make_pair(rank, src),tag))!= 0) {
AstraSim::ncclFlowTag pending_tag = receiver_pending_queue[std::make_pair(std::make_pair(rank, src),tag)];
receiver_pending_queue.erase(std::make_pair(std::make_pair(rank,src),tag));
ehd->flowTag = pending_tag;
}
#ifdef NS3_MTP
cs.ExitSection();
#endif
t.msg_handler(t.fun_arg);
goto sim_recv_end_section;
} else {
recvHash.erase(make_pair(tag, make_pair(t.src, t.dest)));
t.count -= count;
expeRecvHash[make_pair(tag, make_pair(t.src, t.dest))] = t;
}
} else {
if (expeRecvHash.find(make_pair(tag, make_pair(t.src, t.dest))) ==
expeRecvHash.end()) {
expeRecvHash[make_pair(tag, make_pair(t.src, t.dest))] = t;
NcclLog->writeLog(NcclLogLevel::DEBUG," 网络包后到,先进行注册 recvHash do not find expeRecvHash.new make src %d dest %d t.count: %d channel_id %d current_flow_id %d",t.src,t.dest,t.count,tag,flowTag.current_flow_id);
} else {
uint64_t expecount =
expeRecvHash[make_pair(tag, make_pair(t.src, t.dest))].count;
NcclLog->writeLog(NcclLogLevel::DEBUG," 网络包后到,重复注册 recvHash do not find expeRecvHash.add make src %d dest %d expecount: %d t.count: %d tag_id %d current_flow_id %d",t.src,t.dest,expecount,t.count,tag,flowTag.current_flow_id);
}
}
#ifdef NS3_MTP
cs.ExitSection();
#endif
sim_recv_end_section:
return 0;
}