in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [855:966]
// Ring-style transfer-time model shared by Layer::compute_time:
//   data_size * GBps / bus_bw * 1e9 * traffic_factor * (nranks - 1) / nranks
// The expression keeps the original left-to-right operation order (a factor
// of 1 is an exact no-op in floating point), so the value that gets
// truncated to Tick is identical to the previously inlined formulas.
//
// @param data_size       payload handed to compute_time (presumably bytes).
// @param bus_bw          bandwidth figure read from UserParam::net_work_param.
// @param traffic_factor  2 for All_Reduce, 1 for the other collectives.
// @param nranks          number of ranks in the ring.
// @return the modelled time, or 0 when the inputs cannot produce a finite,
//         non-negative value.
static Tick ring_collective_time(
    uint64_t data_size,
    float bus_bw,
    int traffic_factor,
    int nranks) {
  // One rank (or none) exchanges nothing. For nranks == 0 the formula would
  // divide by zero, and converting the resulting -inf/NaN to an integral
  // Tick is undefined behaviour. For nranks == 1 the formula yields 0 anyway.
  if (nranks <= 1) {
    return 0;
  }
  // A zero, negative or NaN bandwidth would give an infinite, negative or NaN
  // time. Converting any of those to Tick is undefined behaviour, so treat the
  // link as unmodelled, like the other unmodelled cases in compute_time.
  if (!(bus_bw > 0.0f)) {
    return 0;
  }
  return data_size * GBps / bus_bw * 1e9 * traffic_factor *
      (nranks - 1) / (nranks / 1.0);
}

// Estimates the time one communication collective takes, using the bus
// bandwidths configured in UserParam::net_work_param.
//
// The cost is ring-style: payload * (nranks - 1) / nranks / bandwidth,
// doubled for All_Reduce. The bandwidth used depends on the group type and
// on whether the group is considered to fit inside one server:
//   TP        : All_Reduce -> tp_ar, All_Gather/Reduce_Scatter -> tp_ag,
//               All_to_All -> tp_ata. This applies both inside a server and
//               across servers (the two cases were always identical).
//   EP        : modelled only when tp_size <= gpus_per_server.
//               All_Reduce -> tp_ar, All_to_All -> ep_ata, other collectives 0.
//   DP, DP_EP : inside a server (all_gpus == gpus_per_server && tp_size <=
//               gpus_per_server), only All_Reduce is modelled, using tp_ar.
//               Across servers, DP uses dp_ar / dp_ag and DP_EP uses
//               ep_ar / ep_ag. All_Gather, Reduce_Scatter and All_to_All
//               share the *_ag figure.
//   anything else (other group types, other collectives, ComType::None): 0.
// Calibration notes carried over from the original code (meaning not
// documented here): in-server All_Reduce "tp2 ep8 164.8 tp16 218";
// DP / DP_EP *_ag paths "tp2 ep8 48.5".
//
// @param comtype     collective kind; ComType::None costs nothing.
// @param tp_size     only used to decide whether the group fits in a server.
// @param nranks      ring size used by the cost formula; <= 1 yields 0.
// @param data_size   payload size (presumably bytes; confirm against GBps).
// @param group_type  communicator group the collective runs on.
// @param all_gpus    compared with gpus_per_server for the DP / DP_EP
//                    in-server test.
// @param ep_size     currently unused.
// @return modelled duration as a Tick (units follow GBps / 1e9 scaling;
//         presumably nanoseconds, to be confirmed).
Tick Layer::compute_time(
    ComType comtype,
    int tp_size,
    int nranks,
    uint64_t data_size,
    MockNccl::GroupType group_type,
    int all_gpus,
    int ep_size) {
  (void)ep_size;  // Not used by the current model; kept for interface compatibility.
  UserParam* param = UserParam::getInstance();
  if (comtype == ComType::None) {
    return 0;
  }

  const uint32_t gpus_per_server = param->net_work_param.gpus_per_server;
  const float tp_ar = param->net_work_param.tp_ar;
  const float tp_ag = param->net_work_param.tp_ag;
  const float tp_ata = param->net_work_param.tp_ata;
  const float ep_ata = param->net_work_param.ep_ata;
  const float dp_ag = param->net_work_param.dp_ag;
  const float ep_ag = param->net_work_param.ep_ag;
  const float dp_ar = param->net_work_param.dp_ar;
  const float ep_ar = param->net_work_param.ep_ar;

  const bool is_all_reduce = (comtype == ComType::All_Reduce);
  const bool is_ag_or_rs = (comtype == ComType::All_Gather ||
                            comtype == ComType::Reduce_Scatter);
  const bool is_all_to_all = (comtype == ComType::All_to_All);

  // The int sizes are compared as uint32_t, exactly as the previous implicit
  // int/uint32_t comparisons did. A negative size converts to a huge unsigned
  // value and therefore does not count as fitting in a server.
  const bool tp_fits_in_server =
      static_cast<uint32_t>(tp_size) <= gpus_per_server;

  switch (group_type) {
    case MockNccl::GroupType::TP:
      if (is_all_reduce) {
        return ring_collective_time(data_size, tp_ar, 2, nranks);
      }
      if (is_ag_or_rs) {
        return ring_collective_time(data_size, tp_ag, 1, nranks);
      }
      if (is_all_to_all) {
        return ring_collective_time(data_size, tp_ata, 1, nranks);
      }
      break;

    case MockNccl::GroupType::EP:
      // NOTE(review): EP placement is decided from tp_size, not ep_size, and
      // EP collectives that do not fit in a server cost 0. This is kept as-is;
      // confirm it is intended.
      if (!tp_fits_in_server) {
        break;
      }
      if (is_all_reduce) {
        return ring_collective_time(data_size, tp_ar, 2, nranks);
      }
      if (is_all_to_all) {
        return ring_collective_time(data_size, ep_ata, 1, nranks);
      }
      break;

    case MockNccl::GroupType::DP:
    case MockNccl::GroupType::DP_EP: {
      const bool dp_fits_in_server =
          static_cast<uint32_t>(all_gpus) == gpus_per_server &&
          tp_fits_in_server;
      if (dp_fits_in_server) {
        // NOTE(review): in-server DP / DP_EP All_Gather, Reduce_Scatter and
        // All_to_All are modelled as free (0). This is kept as-is; confirm it
        // is intended.
        if (is_all_reduce) {
          return ring_collective_time(data_size, tp_ar, 2, nranks);
        }
        break;
      }
      const bool is_dp = (group_type == MockNccl::GroupType::DP);
      if (is_all_reduce) {
        return ring_collective_time(
            data_size, is_dp ? dp_ar : ep_ar, 2, nranks);
      }
      if (is_ag_or_rs || is_all_to_all) {
        return ring_collective_time(
            data_size, is_dp ? dp_ag : ep_ag, 1, nranks);
      }
      break;
    }

    default:
      break;
  }
  return 0;
}