Tick Layer::compute_time()

in astra-sim-alibabacloud/astra-sim/workload/Layer.cc [855:966]


/**
 * Estimate the time (as a Tick) spent on one collective communication of this
 * layer, from the per-group bandwidths configured in
 * UserParam::net_work_param (tp_ar, tp_ag, tp_ata, ep_ata, dp_ar, dp_ag,
 * ep_ar, ep_ag).
 *
 * Every estimate uses the ring form
 *     data_size * GBps / bw * 1e9 * factor * (nranks - 1) / nranks
 * with factor = 2 for All_Reduce and 1 for all other collectives.
 *
 * @param comtype    collective type; ComType::None costs nothing.
 * @param tp_size    size of the TP/EP group; compared with gpus_per_server to
 *                   decide whether the traffic stays inside one server.
 * @param nranks     number of ranks in the collective; values <= 1 yield 0.
 * @param data_size  payload size fed to the formula (presumably bytes, scaled
 *                   by GBps — NOTE(review): confirm units).
 * @param group_type communication group (TP, EP, DP, DP_EP).
 * @param all_gpus   total GPU count; DP/DP_EP traffic is treated as
 *                   intra-server only when it equals gpus_per_server.
 * @param ep_size    currently unused.
 * @return estimated communication time; 0 for combinations not modelled.
 */
Tick Layer::compute_time(
    ComType comtype,
    int tp_size,
    int nranks,
    uint64_t data_size,
    MockNccl::GroupType group_type,
    int all_gpus,
    int ep_size) {
  using GroupType = MockNccl::GroupType;
  UserParam* param = UserParam::getInstance();
  if (comtype == ComType::None) {
    return 0;
  }
  // A single rank exchanges nothing ((nranks - 1) == 0). For nranks <= 0 the
  // formula below would divide by zero and convert +/-inf or NaN to the
  // integral Tick type, which is undefined behavior.
  if (nranks <= 1) {
    return 0;
  }

  const auto& net = param->net_work_param;
  // Kept as uint32_t so the comparisons below have the same (unsigned)
  // semantics as before.
  const uint32_t gpus_per_server = net.gpus_per_server;

  // Decide whether the traffic of this group stays within one server.
  // NOTE: EP is matched by the first test, so EP never takes the
  // DP-style (all_gpus based) path.
  bool TP_comm_inside = false;
  bool DP_comm_inside = false;
  if (group_type == GroupType::TP || group_type == GroupType::EP) {
    TP_comm_inside = tp_size <= gpus_per_server;
  } else if (
      group_type == GroupType::DP || group_type == GroupType::DP_EP) {
    DP_comm_inside =
        all_gpus == gpus_per_server && tp_size <= gpus_per_server;
  }

  // Ring transfer time for bandwidth `bw`. `bw` is taken as float because the
  // bandwidths were previously copied into float locals before use. The
  // expression keeps the original left-to-right evaluation order, so results
  // are bit-identical; factor 1 is an exact identity on the double value.
  // The 2x for All_Reduce presumably models reduce-scatter + all-gather.
  auto ring_time = [&](float bw, int factor) -> Tick {
    return data_size * GBps / bw * 1e9 * factor * (nranks - 1) /
        (nranks / 1.0);
  };

  const bool is_all_reduce = comtype == ComType::All_Reduce;
  const bool is_ag_or_rs = comtype == ComType::All_Gather ||
      comtype == ComType::Reduce_Scatter;
  const bool is_all_to_all = comtype == ComType::All_to_All;

  if (TP_comm_inside || DP_comm_inside) {
    // Intra-server: any group's All_Reduce is costed with tp_ar; the other
    // collectives are only modelled for TP (AG/RS, A2A) and EP (A2A).
    if (is_all_reduce) {
      return ring_time(net.tp_ar, 2);
    }
    if (group_type == GroupType::TP && is_ag_or_rs) {
      return ring_time(net.tp_ag, 1);
    }
    if (group_type == GroupType::TP && is_all_to_all) {
      return ring_time(net.tp_ata, 1);
    }
    if (group_type == GroupType::EP && is_all_to_all) {
      return ring_time(net.ep_ata, 1);
    }
    return 0;
  }

  // From here on neither *_comm_inside flag is set.
  if (group_type == GroupType::TP) {
    if (is_all_reduce) {
      return ring_time(net.tp_ar, 2);
    }
    if (is_ag_or_rs) {
      return ring_time(net.tp_ag, 1);
    }
    if (is_all_to_all) {
      return ring_time(net.tp_ata, 1);
    }
    return 0;
  }
  if (group_type == GroupType::DP) {
    if (is_all_reduce) {
      return ring_time(net.dp_ar, 2);
    }
    if (is_ag_or_rs || is_all_to_all) {
      return ring_time(net.dp_ag, 1);  // tp2 ep8 48.5
    }
    return 0;
  }
  if (group_type == GroupType::DP_EP) {
    if (is_all_reduce) {
      return ring_time(net.ep_ar, 2);
    }
    if (is_ag_or_rs || is_all_to_all) {
      return ring_time(net.ep_ag, 1);  // tp2 ep8 48.5
    }
    return 0;
  }
  // EP spanning more than one server, and any other group type, is not
  // modelled here.
  return 0;
}