bool Workload::initialize_workload(std::string name)

in astra-sim-alibabacloud/astra-sim/workload/Workload.cc [1134:1549]


bool Workload::initialize_workload(std::string name) {
  std::map<int, bool> chekpoints;
  std::map<int, bool> need_checkpoint_initiation;
  std::ifstream inFile;
  inFile.open(name);
  if (!inFile) {
    std::cerr << "Unable to open file: " << name << std::endl;
    std::cerr << "######### Exiting because unable to open the workload input "
                 "file #########"
              << std::endl;
    std::cerr << "This error is fatal. Please check your path and filename."
              << std::endl;
    exit(1);
  } else {
    if (generator->id == 0) {
      std::cout << "Success in opening workload file" << std::endl;
    }
  }
 std::string firstline;
  std::getline(inFile,firstline);
  // std::cout << "First line is : '" << firstline << "'" << std::endl;
  std::istringstream iss(firstline);
  std:string token;
  std::vector<std::string> tokens;
  // bool findparallesimPolcy = false;
  
  while (iss >> token) {
        tokens.push_back(token);
        // std::cout << "Token is : '" << token << "'" << std::endl;
    }



  if(!tokens.empty()){
    parallelismPolicy = decode_parallelsim(tokens[0]);
  }

  if (parallelismPolicy == ParallelismPolicy::TransformerFwdInBckwd ||
      parallelismPolicy == ParallelismPolicy::Transformer) {
        for (size_t i = 1; i < tokens.size(); i = i+1){
          if(tokens[i]=="model_parallel_NPU_group:"){
            model_parallel_npu_group = std::stoi(tokens[i+1]);
            if (generator->id == 0) {
              std::cout <<"model_parallel_NPU_group is " << model_parallel_npu_group << std::endl;
            }
          }else if(tokens[i]=="ep:"){
            expert_parallel_npu_group = std::stoi(tokens[i+1]);
          }else if(tokens[i]== "pp:"){
            pipeline_model_parallelism = std::stoi(tokens[i+1]);
          }else if(tokens[i]=="vpp:"){
            vpp = std::stoi(tokens[i+1]);
          }else if(tokens[i]=="ga:"){
            GA = std::stoi(tokens[i+1]);
          }else if(tokens[i]=="all_gpus:"){
            all_gpus = std::stoi(tokens[i+1]);
          }
        }

        if(parallelismPolicy == ParallelismPolicy::TransformerFwdInBckwd){
          if (generator->id == 0) {
            std::cout << "checkpoints layers are: ";
          }
          for(size_t i = 1; i < tokens.size(); i = i+1){
            if(tokens[i]=="checkpoints:"){
              int account = std::stoi(tokens[i+1]);
              while(account-- >0){
                int j = 2;
                int layer = std::stoi(tokens[i+j]);
                chekpoints[layer] = true;
                if (generator->id == 0) {
                  std::cout << layer << ", ";
                }
                j++;
              }
                
            }else if(tokens[i]=="checkpoint_initiates:"){
                if (generator->id == 0) {
                  std::cout << std::endl;
                  std::cout << "layers initiating fwd_in_bckwd are: ";
                }
                int account = std::stoi(tokens[i+1]);
                while(account-- >0){
                  int j = 2;
                  int layer = std::stoi(tokens[i+j]);
                  need_checkpoint_initiation[layer] = true;
                  if (generator->id == 0) {
                    std::cout << layer << ", ";
                  }
                  j++;
                }
                if (generator->id == 0) {
                  std::cout << std::endl;
                }
              }
            }
          }
      }else if(parallelismPolicy == ParallelismPolicy::DLRM ||
                parallelismPolicy == ParallelismPolicy::DLRMEnhanced){
                  for (size_t i = 1; i < tokens.size(); i = i+1){
                    if(tokens[i]=="DLRM_LAST_BOTTOM_LAYER:"){
                      DLRM_LAST_BOTTOM_LAYER = std::stoi(tokens[i+1]);
                    }
                  }
                if (generator->id == 0) {
                  std::cout
                  << "****************** info: DLRM workload last bottom layer is: "
                  << DLRM_LAST_BOTTOM_LAYER << std::endl;
                }
        }else if (parallelismPolicy == ParallelismPolicy::None) {
          #ifndef PHY_MTP
          std::cerr << "######### Exiting because unable to decode the workload "
                 "parallelization strategy #########"
                  << std::endl;
          inFile.close();
          exit(1);
          #else
          parallelismPolicy = ParallelismPolicy::TransformerFwdInBckwd;
          #endif
  }
  std::map<std::string, std::vector<bool>> general_involved_dimensions =
      decode_involved_dimensions(parallelismPolicy, model_parallel_npu_group);
  pp_commsize = 0;
  for (size_t i = 1; i < tokens.size(); i = i+1){
    if(tokens[i]=="pp_comm"||tokens[i]=="pp_comm:"){
      pp_commsize = std::stoi(tokens[i+1]);
    }
  }
  if (generator->id == 0) {
      std::cout <<"pp_commize:"<< pp_commsize << std::endl;
  }
  if(generator->id == 0){
    if (model_parallel_npu_group == 0 || expert_parallel_npu_group == 0 || pipeline_model_parallelism == 0 
        || vpp==0 || GA == 0 || all_gpus == 0 ||(pipeline_model_parallelism !=1 && pp_commsize ==0)||(pipeline_model_parallelism == 1 && pp_commsize !=0)){
          std::cerr << "*****Warining: Input workload format mismatch. It may cause simulation error. Pleased use the latest AICB to generate.*****" << std::endl;
      }
  }        
  run_type = tokens[0];
  std::string secondline;
  std::getline(inFile,secondline);

  int lines;
  // std::cout << "Second line content: '" << secondline << "'" << std::endl;
  lines = std::stoi(secondline);


  SIZE = lines;
  layers = new Layer*[SIZE];
  for (int i = 0; i < lines; i++) {
    std::string id;
    inFile >> id;
    int depen;
    inFile >> depen;

    Tick fp_compute_time;
    inFile >> fp_compute_time;
    std::string fp_comm_type_s;
    inFile >> fp_comm_type_s;
    uint64_t fp_comm_size;
    inFile >> fp_comm_size;

    Tick ig_compute_time;
    inFile >> ig_compute_time;
    std::string ig_comm_type_s;
    inFile >> ig_comm_type_s;
    uint64_t ig_comm_size;
    inFile >> ig_comm_size;

    Tick wg_compute_time;
    inFile >> wg_compute_time;
    std::string wg_comm_type_s;
    inFile >> wg_comm_type_s;
    uint64_t wg_comm_size;
    inFile >> wg_comm_size;
    Tick wg_update_time;
    inFile >> wg_update_time;

    ParallelismPolicy specific_policy = ParallelismPolicy::None;
    std::map<std::string, std::vector<bool>> selected_involved_dimensions;
    ComType fp_type = ComType::None;
    ComType ig_type = ComType::None;
    ComType wg_type = ComType::None;
    MockNccl::GroupType fp_group_type = MockNccl::GroupType::NONE;
    MockNccl::GroupType ig_group_type = MockNccl::GroupType::NONE;
    MockNccl::GroupType wg_group_type = MockNccl::GroupType::NONE;
    if (wg_comm_type_s.substr(0,9) == "ALLREDUCE") {
      wg_type = ComType::All_Reduce;
      if(wg_comm_type_s == "ALLREDUCE"){
        wg_group_type = MockNccl::GroupType::DP;
      } else if(wg_comm_type_s == "ALLREDUCE_EP"){
        wg_group_type = MockNccl::GroupType::EP;
      } else if(wg_comm_type_s == "ALLREDUCE_DP_EP"){
        wg_group_type = MockNccl::GroupType::DP_EP;
      } else{
        wg_group_type = MockNccl::GroupType::NONE;
      }
    } else if (wg_comm_type_s.substr(0,8) == "ALLTOALL") {
      wg_type = ComType::All_to_All;
      if(wg_comm_type_s == "ALLTOALL"){
        wg_group_type = MockNccl::GroupType::DP;
      } else if(wg_comm_type_s == "ALLTOALL_EP"){
        wg_group_type = MockNccl::GroupType::EP;
      } else if(wg_comm_type_s == "ALLTOALL_DP_EP"){
        wg_group_type = MockNccl::GroupType::DP_EP;
      } else{
        wg_group_type = MockNccl::GroupType::NONE;
      }
    } else if (wg_comm_type_s.substr(0,17) == "ALLREDUCEALLTOALL") {
      wg_type = ComType::All_Reduce_All_to_All;
      if(wg_comm_type_s == "ALLREDUCEALLTOALL"){
        wg_group_type = MockNccl::GroupType::DP;
      } else if(wg_comm_type_s == "ALLREDUCEALLTOALL_EP"){
        wg_group_type = MockNccl::GroupType::EP;
      } else if(wg_comm_type_s == "ALLREDUCEALLTOALL_DP_EP"){
        wg_group_type = MockNccl::GroupType::DP_EP;
      } else{
        wg_group_type = MockNccl::GroupType::NONE;
      }
    } else if (wg_comm_type_s.substr(0,9) == "ALLGATHER") {
      wg_type = ComType::All_Gather;
      if(wg_comm_type_s == "ALLGATHER"){
        wg_group_type = MockNccl::GroupType::DP;
      } else if(wg_comm_type_s == "ALLGATHER_EP"){
        wg_group_type = MockNccl::GroupType::EP;
      } else if(wg_comm_type_s == "ALLGATHER_DP_EP"){
        wg_group_type = MockNccl::GroupType::DP_EP;
      } else{
        wg_group_type = MockNccl::GroupType::NONE;
      }
    } else if (wg_comm_type_s.substr(0,13) == "REDUCESCATTER") {
      wg_type = ComType::Reduce_Scatter;
      if(wg_comm_type_s == "REDUCESCATTER"){
        wg_group_type = MockNccl::GroupType::DP;
      } else if(wg_comm_type_s == "REDUCESCATTER_EP"){
        wg_group_type = MockNccl::GroupType::EP;
      } else if(wg_comm_type_s == "REDUCESCATTER_DP_EP"){
        wg_group_type = MockNccl::GroupType::DP_EP;
      } else{
        wg_group_type = MockNccl::GroupType::NONE;
      }
    }

    // generate flow model

    if (ig_comm_type_s.substr(0,9) == "ALLREDUCE") {
      ig_type = ComType::All_Reduce;
      if(ig_comm_type_s == "ALLREDUCE"){
        ig_group_type = MockNccl::GroupType::TP;
      } else if(ig_comm_type_s == "ALLREDUCE_EP"){
        ig_group_type = MockNccl::GroupType::EP;
      } else if(ig_comm_type_s == "ALLREDUCE_DP_EP"){
        ig_group_type = MockNccl::GroupType::DP_EP;
      } else{
        ig_group_type = MockNccl::GroupType::NONE;
      }
    } else if (ig_comm_type_s.substr(0,8) == "ALLTOALL") {
      ig_type = ComType::All_to_All;
      if(ig_comm_type_s == "ALLTOALL"){
        ig_group_type = MockNccl::GroupType::TP;
      } else if(ig_comm_type_s == "ALLTOALL_EP"){
        ig_group_type = MockNccl::GroupType::EP;
      } else if(ig_comm_type_s == "ALLTOALL_DP_EP"){
        ig_group_type = MockNccl::GroupType::DP_EP;
      } else{
        ig_group_type = MockNccl::GroupType::NONE;
      }
    } else if (ig_comm_type_s.substr(0,17) == "ALLREDUCEALLTOALL") {
      ig_type = ComType::All_Reduce_All_to_All;
      if(ig_comm_type_s == "ALLREDUCEALLTOALL"){
        ig_group_type = MockNccl::GroupType::TP;
      } else if(ig_comm_type_s == "ALLREDUCEALLTOALL_EP"){
        ig_group_type = MockNccl::GroupType::EP;
      } else if(ig_comm_type_s == "ALLREDUCEALLTOALL_DP_EP"){
        ig_group_type = MockNccl::GroupType::DP_EP;
      } else{
        ig_group_type = MockNccl::GroupType::NONE;
      }
    } else if (ig_comm_type_s.substr(0,9) == "ALLGATHER") {
      ig_type = ComType::All_Gather;
      if(ig_comm_type_s == "ALLGATHER"){
        ig_group_type = MockNccl::GroupType::TP;
      } else if(ig_comm_type_s == "ALLGATHER_EP"){
        ig_group_type = MockNccl::GroupType::EP;
      } else if(ig_comm_type_s == "ALLGATHER_DP_EP"){
        ig_group_type = MockNccl::GroupType::DP_EP;
      } else{
        ig_group_type = MockNccl::GroupType::NONE;
      }
    } else if (ig_comm_type_s.substr(0,13) == "REDUCESCATTER") {
      ig_type = ComType::Reduce_Scatter;
      if(ig_comm_type_s == "REDUCESCATTER"){
        ig_group_type = MockNccl::GroupType::TP;
      } else if(ig_comm_type_s == "REDUCESCATTER_EP"){
        ig_group_type = MockNccl::GroupType::EP;
      } else if(ig_comm_type_s == "REDUCESCATTER_DP_EP"){
        ig_group_type = MockNccl::GroupType::DP_EP;
      } else{
        ig_group_type = MockNccl::GroupType::NONE;
      }
    }

    if (fp_comm_type_s.substr(0,9) == "ALLREDUCE") {
      fp_type = ComType::All_Reduce;
      if(fp_comm_type_s == "ALLREDUCE"){
        fp_group_type = MockNccl::GroupType::TP;
      } else if(fp_comm_type_s == "ALLREDUCE_EP"){
        fp_group_type = MockNccl::GroupType::EP;
      } else if(fp_comm_type_s == "ALLREDUCE_DP_EP"){
        fp_group_type = MockNccl::GroupType::DP_EP;
      } else{
        fp_group_type = MockNccl::GroupType::NONE;
      }
    } else if (fp_comm_type_s.substr(0,8) == "ALLTOALL") {
      fp_type = ComType::All_to_All;
      if(fp_comm_type_s == "ALLTOALL"){
        fp_group_type = MockNccl::GroupType::TP;
      } else if(fp_comm_type_s == "ALLTOALL_EP"){
        fp_group_type = MockNccl::GroupType::EP;
      } else if(fp_comm_type_s == "ALLTOALL_DP_EP"){
        fp_group_type = MockNccl::GroupType::DP_EP;
      } else{
        fp_group_type = MockNccl::GroupType::NONE;
      }
    } else if (fp_comm_type_s.substr(0,17) == "ALLREDUCEALLTOALL") {
      fp_type = ComType::All_Reduce_All_to_All;
      if(fp_comm_type_s == "ALLREDUCEALLTOALL"){
        fp_group_type = MockNccl::GroupType::TP;
      } else if(fp_comm_type_s == "ALLREDUCEALLTOALL_EP"){
        fp_group_type = MockNccl::GroupType::EP;
      } else if(fp_comm_type_s == "ALLREDUCEALLTOALL_DP_EP"){
        fp_group_type = MockNccl::GroupType::DP_EP;
      } else{
        fp_group_type = MockNccl::GroupType::NONE;
      }
    } else if (fp_comm_type_s.substr(0,9) == "ALLGATHER") {
      fp_type = ComType::All_Gather;
      if(fp_comm_type_s == "ALLGATHER"){
        fp_group_type = MockNccl::GroupType::TP;
      } else if(fp_comm_type_s == "ALLGATHER_EP"){
        fp_group_type = MockNccl::GroupType::EP;
      } else if(fp_comm_type_s == "ALLGATHER_DP_EP"){
        fp_group_type = MockNccl::GroupType::DP_EP;
      } else{
        fp_group_type = MockNccl::GroupType::NONE;
      }
    } else if (fp_comm_type_s.substr(0,13) == "REDUCESCATTER") {
      fp_type = ComType::Reduce_Scatter;
      if(fp_comm_type_s == "REDUCESCATTER"){
        fp_group_type = MockNccl::GroupType::TP;
      } else if(fp_comm_type_s == "REDUCESCATTER_EP"){
        fp_group_type = MockNccl::GroupType::EP;
      } else if(fp_comm_type_s == "REDUCESCATTER_DP_EP"){
        fp_group_type = MockNccl::GroupType::DP_EP;
      } else{
        fp_group_type = MockNccl::GroupType::NONE;
      }
    }
    if (generator->id == 0) {
      std::cout << "id: " << id << " , depen: " << depen
                << " , wg_comp_time: " << wg_compute_time << std::endl;
    }
    if (parallelismPolicy == ParallelismPolicy::HybridCustomized) {
      std::string specific_parallelsim;
      inFile >> specific_parallelsim;
      specific_policy = decode_parallelsim(specific_parallelsim);
    }
    if ((parallelismPolicy == ParallelismPolicy::DLRM ||
         parallelismPolicy == ParallelismPolicy::DLRMEnhanced) &&
        i == 0) {
      specific_policy = ParallelismPolicy::All;
    }
    if (specific_policy != ParallelismPolicy::None) {
      selected_involved_dimensions =
          decode_involved_dimensions(specific_policy, model_parallel_npu_group);
    } else {
      selected_involved_dimensions = general_involved_dimensions;
    }
    Layer* l = new Layer(
        id,
        i,
        generator,
        this,
        fp_compute_time * generator->compute_scale,
        fp_type,
        fp_group_type,
        fp_comm_size * generator->comm_scale,
        selected_involved_dimensions["fwd"],
        ig_compute_time * generator->compute_scale,
        ig_type,
        ig_group_type,
        ig_comm_size * generator->comm_scale,
        selected_involved_dimensions["ig"],
        wg_compute_time * generator->compute_scale,
        wg_type,
        wg_group_type,
        wg_comm_size * generator->comm_scale,
        selected_involved_dimensions["wg"],
        wg_update_time,
        specific_policy);
    if (chekpoints.find(i) != chekpoints.end()) {
      l->is_checkpoint = true;
    }
    if (need_checkpoint_initiation.find(i) !=
        need_checkpoint_initiation.end()) {
      l->needs_fwd_in_bckwd_initiation = true;
    }
    layers[i] = l;
  }
  if (generator->id == 0) {
    std::cout << "type: " << run_type << " ,num passes: " << TOTAL_PASS
              << " ,lines: " << lines
              << " compute scale: " << generator->compute_scale
              << " ,comm scale: " << generator->comm_scale << std::endl;
  }
  inFile.close();
  return true;
}