astra-sim-alibabacloud/astra-sim/system/PhyMultiThread.cc (199 lines of code) (raw):
/*
*Copyright (c) 2024, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
*/
#include<chrono>
#include "PhyMultiThread.hh"
// RDMA flow object declared here and defined elsewhere; the polling loop below
// queries it through recv_wr_id_to_buff() / send_wr_id_to_buff().
extern FlowPhyRdma flow_rdma;
// Definition of the PhyMtpInterface static flag; starts out false.
std::atomic<bool> PhyMtpInterface::g_e_inCriticalSection (false);
// Number of receive completions counted so far, keyed by current_flow_id.
std::map<int,std::atomic<int>> all_recv_size;
// Number of send completions counted so far, keyed by current_flow_id.
std::map<int,std::atomic<int>> all_send_size;
// Stop request for create_polling_cqe_thread(): the loop runs while this is
// false, and notify_all_thread_finished() sets it to true.
// NOTE(review): this is a plain bool with no synchronization. If the poller
// and the notifier run on different threads this is a data race — consider
// std::atomic<bool> after checking for extern declarations in other files.
bool end_flag = false;
// Completion callbacks called by insert_send_cqe() / insert_recv_cqe().
// As namespace-scope pointers without an initializer they are null until the
// matching set_*_finished_callback() has been called.
void (*send_finished_callback)(AstraSim::ncclFlowTag flowTag);
void (*receive_finished_callback)(AstraSim::ncclFlowTag flowTag);
// Registers the function that insert_send_cqe() calls with the flow tag of a
// completed send. Any previously registered handler is overwritten.
//
// msg_handler: function pointer stored in send_finished_callback as-is.
void set_send_finished_callback(
    void (*msg_handler)(AstraSim::ncclFlowTag flowTag)) {
  send_finished_callback = msg_handler;
}
// Registers the function that insert_recv_cqe() calls with the flow tag of a
// completed receive. Any previously registered handler is overwritten.
//
// msg_handler: function pointer stored in receive_finished_callback as-is.
void set_receive_finished_callback(
    void (*msg_handler)(AstraSim::ncclFlowTag flowTag)) {
  receive_finished_callback = msg_handler;
}
static void
insert_recv_cqe(void* buff) {
MockNcclLog* NcclLog = MockNcclLog::getInstance();
TransportData* ptrrecvdata = reinterpret_cast<TransportData*> (buff);
AstraSim::ncclFlowTag flowTag = AstraSim::ncclFlowTag(
ptrrecvdata->channel_id,
ptrrecvdata->chunk_id,
ptrrecvdata->current_flow_id,
ptrrecvdata->child_flow_id,
ptrrecvdata->sender_node,
ptrrecvdata->receiver_node,
ptrrecvdata->flow_size,
ptrrecvdata->pQps,
ptrrecvdata->tag_id,
ptrrecvdata->nvls_on);
NcclLog->writeLog(NcclLogLevel::DEBUG,"PhyMultiThread.cc::insert_recv_cqe src_id %d dst_id %d flow_id %d channel_id %d",flowTag.sender_node,flowTag.receiver_node,flowTag.current_flow_id,flowTag.channel_id);
flowTag.tree_flow_list.clear();
for(int i =0;i<ptrrecvdata->child_flow_size;i++){
flowTag.tree_flow_list.push_back(ptrrecvdata->child_flow_list[i]);
}
receive_finished_callback(flowTag);
}
static void
insert_send_cqe(void* buff) {
MockNcclLog* NcclLog = MockNcclLog::getInstance();
TransportData* ptrrecvdata = reinterpret_cast<TransportData*> (buff);
AstraSim::ncclFlowTag flowTag = AstraSim::ncclFlowTag(
ptrrecvdata->channel_id,
ptrrecvdata->chunk_id,
ptrrecvdata->current_flow_id,
ptrrecvdata->child_flow_id,
ptrrecvdata->sender_node,
ptrrecvdata->receiver_node,
ptrrecvdata->flow_size,
ptrrecvdata->pQps,
ptrrecvdata->tag_id,
ptrrecvdata->nvls_on);
NcclLog->writeLog(NcclLogLevel::DEBUG,"PhyMultiThread.cc::insert_send_cqe src_id %d dst_id %d flow_id %d channel_id %d",flowTag.sender_node,flowTag.receiver_node,flowTag.current_flow_id,flowTag.channel_id);
flowTag.tree_flow_list.clear();
for(int i =0;i<ptrrecvdata->child_flow_size;i++){
flowTag.tree_flow_list.push_back(ptrrecvdata->child_flow_list[i]);
}
send_finished_callback(flowTag);
}
static bool
judge_polling_all_recv_cqe(void * buff){
MockNcclLog* NcclLog = MockNcclLog::getInstance();
TransportData* ptrrecvdata = reinterpret_cast<TransportData*> (buff);
AstraSim::ncclFlowTag flowTag = AstraSim::ncclFlowTag(
ptrrecvdata->channel_id,
ptrrecvdata->chunk_id,
ptrrecvdata->current_flow_id,
ptrrecvdata->child_flow_id,
ptrrecvdata->sender_node,
ptrrecvdata->receiver_node,
ptrrecvdata->flow_size,
ptrrecvdata->pQps,
ptrrecvdata->tag_id,
ptrrecvdata->nvls_on);
int temp = 0;
{
MockNcclLog*NcclLog = MockNcclLog::getInstance();
if (!all_recv_size.count(flowTag.current_flow_id)) {
all_recv_size[flowTag.current_flow_id] = 1;
} else {
all_recv_size[flowTag.current_flow_id]++;
}
temp = all_recv_size[flowTag.current_flow_id];
NcclLog->writeLog(NcclLogLevel::DEBUG,"judge_polling_all_recv_cqe flow_id %d recv_cqe_size %d",flowTag.current_flow_id,temp);
}
if (temp == NCCL_QPS_PER_PEER) {
return true;
} else {
return false;
}
}
static bool
judge_polling_all_send_cqe(void * buff){
MockNcclLog* NcclLog = MockNcclLog::getInstance();
TransportData* ptrrecvdata = reinterpret_cast<TransportData*> (buff);
AstraSim::ncclFlowTag flowTag = AstraSim::ncclFlowTag(
ptrrecvdata->channel_id,
ptrrecvdata->chunk_id,
ptrrecvdata->current_flow_id,
ptrrecvdata->child_flow_id,
ptrrecvdata->sender_node,
ptrrecvdata->receiver_node,
ptrrecvdata->flow_size,
ptrrecvdata->pQps,
ptrrecvdata->tag_id,
ptrrecvdata->nvls_on);
int temp = 0;
{
MockNcclLog*NcclLog = MockNcclLog::getInstance();
if (!all_send_size.count(flowTag.current_flow_id)) {
all_send_size[flowTag.current_flow_id] = 1;
} else {
all_send_size[flowTag.current_flow_id]++;
}
temp = all_send_size[flowTag.current_flow_id];
NcclLog->writeLog(NcclLogLevel::DEBUG,"judge_polling_all_send_cqe flow_id %d send_cqe_size %d",flowTag.current_flow_id,temp);
}
if (temp == NCCL_QPS_PER_PEER) {
return true;
} else {
return false;
}
}
// Polls a completion queue in a tight loop until end_flag becomes true.
//
// With PHY_RDMA defined, every successful completion is looked up in
// flow_rdma to find its transport buffer. Receive completions (IBV_WC_RECV,
// IBV_WC_RECV_RDMA_WITH_IMM) and RDMA-write completions are counted per flow,
// and once a flow's count reaches the threshold the matching finished-callback
// is fired via insert_recv_cqe() / insert_send_cqe(). Without PHY_RDMA the
// loop body is empty and the function only waits for end_flag.
//
// cq_ptr:   completion queue to poll, cast to ibv_cq* (unused without PHY_RDMA).
// lcore_id: accepted for interface compatibility; not used in this function.
// Returns true once the loop has been asked to stop. (The original fell off
// the end of a bool function, which is undefined behaviour.)
//
// NOTE(review): end_flag is a plain bool; if it is set from another thread the
// read below is unsynchronized — confirm, and consider std::atomic<bool>.
bool
create_polling_cqe_thread(void * cq_ptr,int lcore_id){
  (void)lcore_id;
  MockNcclLog* NcclLog = MockNcclLog::getInstance();
#ifdef PHY_RDMA
  ibv_cq* cq = static_cast<ibv_cq*>(cq_ptr);
  struct ibv_wc wc[TEST_IO_DEPTH] = {};
  // Wall-clock timestamp in microseconds, used only for debug logging.
  const auto wall_clock_us = [] {
    return static_cast<long long>(
        std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch())
            .count());
  };
#else
  (void)cq_ptr;
#endif
  NcclLog->writeLog(NcclLogLevel::DEBUG,"PhyMultiThread.cc::create_polling_cqe_thread begin");
  while (!end_flag)
  {
#ifdef PHY_RDMA
    memset(wc, 0, sizeof(wc));
    const int ret = ibv_poll_cq(cq,TEST_IO_DEPTH,wc);
    assert(ret>=0);
    if (ret <= 0) {
      continue;
    }
    NcclLog->writeLog(NcclLogLevel::DEBUG,"PhyMultiThread.cc::create_polling_send_cqe_thread cqe num %d",ret);
    for (int i = 0; i < ret; i++) {
      const struct ibv_wc& cqe = wc[i];
      if (cqe.status != IBV_WC_SUCCESS) {
        NcclLog->writeLog(
            NcclLogLevel::ERROR,
            " wr's status is error %d opcode %d ",
            cqe.status,
            cqe.opcode);
        assert(cqe.status == IBV_WC_SUCCESS);
        // In builds where assert is compiled out, do not dispatch on the
        // opcode of a failed completion: it is not valid on error.
        continue;
      }
      if (cqe.opcode == IBV_WC_RECV ||
          cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
        // wr_id is 64-bit, so it is printed with %llu rather than %d.
        NcclLog->writeLog(
            NcclLogLevel::DEBUG,
            "poll_recv_cqe qpn %u wr_id %llu chunk_id %u time %lld",
            cqe.qp_num,
            static_cast<unsigned long long>(cqe.wr_id),
            cqe.imm_data,
            wall_clock_us());
        void* recv_buff = flow_rdma.recv_wr_id_to_buff(cqe.qp_num, cqe.wr_id, cqe.imm_data);
        if (judge_polling_all_recv_cqe(recv_buff)) {
          insert_recv_cqe(recv_buff);
        }
      } else if (cqe.opcode == IBV_WC_RDMA_WRITE) {
        NcclLog->writeLog(
            NcclLogLevel::DEBUG,
            "poll_send_cqe qpn %u wr_id %llu time %lld",
            cqe.qp_num,
            static_cast<unsigned long long>(cqe.wr_id),
            wall_clock_us());
        void* send_buff = flow_rdma.send_wr_id_to_buff(cqe.qp_num, cqe.wr_id);
        if (judge_polling_all_send_cqe(send_buff)) {
          insert_send_cqe(send_buff);
        }
      }
    }
#endif
  }
  return true;
}
void
notify_all_thread_finished(){
MockNcclLog* NcclLog = MockNcclLog::getInstance();
end_flag = true;
NcclLog->writeLog(NcclLogLevel::DEBUG,"PhyMultiThread::notify_all_thread_finished end");
}