in astra-sim-alibabacloud/astra-sim/system/AstraParamParse.hh [188:296]
/**
 * Parse command-line options into this object's fields.
 *
 * Options (each value-taking option consumes the following argv entry):
 *   -h     / --help                   print help and return 1
 *   -w     / --workload <path>        sets `workload`
 *   -g     / --gpus <n>               appends to `gpus` (may repeat; only the
 *                                     first entry feeds the topology below)
 *   -r     / --result <path>          sets `res` ("None", or a path ending in
 *                                     '/', triggers auto-naming, see below)
 *   -g_p_s / --gpus-per-server <n>    sets net_work_param.gpus_per_server
 *   -busbw / --bus-bandwidth <file>   forwarded to parseYaml()
 *   -dp_o / -tp_o / -ep_o / -pp_o     (--{dp,tp,ep,pp}-overlap-ratio <f>)
 *                                     set the matching overlap ratio
 *   -v     / --visual                 sets net_work_param.visual = 1
 *
 * Numeric values are converted with std::stoi / std::stof, which throw on
 * malformed input; those exceptions are not caught here.
 *
 * @return 0 on success, 1 after --help, otherwise the value returned by
 *         printError() (option missing its value) or printUnknownOption()
 *         (unrecognised option).
 */
int parseArg(int argc, char *argv[]) {
  for (int i = 1; i < argc; ++i) {
    std::string arg = argv[i];
    if (arg == "-h" || arg == "--help") {
      printHelp();
      return 1;
    } else if (arg == "-w" || arg == "--workload") {
      if (++i < argc) this->workload = argv[i];
      else return printError(arg);
    } else if (arg == "-g" || arg == "--gpus") {
      if (++i < argc) this->gpus.push_back(std::stoi(argv[i]));
      else return printError(arg);
    } else if (arg == "-r" || arg == "--result") {
      if (++i < argc) this->res = argv[i];
      else return printError(arg);
    } else if (arg == "-g_p_s" || arg == "--gpus-per-server") {
      if (++i < argc) this->net_work_param.gpus_per_server = std::stoi(argv[i]);
      else return printError(arg);
    } else if (arg == "-busbw" || arg == "--bus-bandwidth") {
      if (++i < argc) parseYaml(this->net_work_param, argv[i]);
      else return printError(arg);
    } else if (arg == "--dp-overlap-ratio" || arg == "-dp_o") {
      if (++i < argc) this->net_work_param.dp_overlap_ratio = std::stof(argv[i]);
      else return printError(arg);
    } else if (arg == "--tp-overlap-ratio" || arg == "-tp_o") {
      if (++i < argc) this->net_work_param.tp_overlap_ratio = std::stof(argv[i]);
      else return printError(arg);
    } else if (arg == "--ep-overlap-ratio" || arg == "-ep_o") {
      if (++i < argc) this->net_work_param.ep_overlap_ratio = std::stof(argv[i]);
      else return printError(arg);
    } else if (arg == "--pp-overlap-ratio" || arg == "-pp_o") {
      if (++i < argc) this->net_work_param.pp_overlap_ratio = std::stof(argv[i]);
      else return printError(arg);
    } else if (arg == "-v" || arg == "--visual") {
      this->net_work_param.visual = 1;
    } else {
      return printUnknownOption(arg);
    }
  }

  if (!this->gpus.empty()) {
    // NOTE(review): divides by gpus_per_server, which is not validated here;
    // a value of 0 would trap. The constant 120 is a fixed switch count whose
    // meaning is not visible in this file — TODO confirm/document.
    this->net_work_param.nvswitch_num = this->gpus[0] / this->net_work_param.gpus_per_server;
    this->net_work_param.switch_num = 120 + this->net_work_param.gpus_per_server;
    this->net_work_param.node_num = this->net_work_param.nvswitch_num + this->net_work_param.switch_num + this->gpus[0];
  }

  // std::string::back() on an empty string is undefined behaviour, so check
  // emptiness first (e.g. when invoked with `-r ""`).
  const bool res_is_dir = !this->res.empty() && this->res.back() == '/';

  if (this->res == "None" || res_is_dir) {
    // Auto-generate a result name from the workload file's basename. The
    // basename is scanned for world_size/tp/pp/ep/gbs/mbs/seq each followed
    // directly by digits; absent fields stay 0.
    const std::string &full_path = this->workload;
    const size_t last_slash_pos = full_path.find_last_of('/');
    const std::string model_info = (last_slash_pos == std::string::npos)
                                       ? full_path
                                       : full_path.substr(last_slash_pos + 1);

    // The model name is everything before "world_size", minus the single
    // separator character preceding it. Guard position 0: `pos - 1` would
    // wrap to npos and yield the whole string.
    std::string model_name;
    const size_t world_size_pos = model_info.find("world_size");
    if (world_size_pos != std::string::npos && world_size_pos != 0) {
      model_name = model_info.substr(0, world_size_pos - 1);
    }

    // `seq` is parsed but not currently part of the generated name.
    int world_size = 0, tp = 0, pp = 0, ep = 0, gbs = 0, mbs = 0, seq = 0;
    const std::regex param_regex(R"((world_size|tp|pp|ep|gbs|mbs|seq)(\d+))");
    for (std::sregex_iterator it(model_info.begin(), model_info.end(), param_regex), end;
         it != end; ++it) {
      const std::string param_name = (*it)[1].str();
      const int param_value = std::stoi((*it)[2].str());
      if (param_name == "world_size") {
        world_size = param_value;
      } else if (param_name == "tp") {
        tp = param_value;
      } else if (param_name == "pp") {
        pp = param_value;
      } else if (param_name == "ep") {
        ep = param_value;
      } else if (param_name == "gbs") {
        gbs = param_value;
      } else if (param_name == "mbs") {
        mbs = param_value;
      } else if (param_name == "seq") {
        seq = param_value;
      }
    }

    // Fields may be missing from the file name. Guard the divisions so that
    // missing tp/pp/mbs produce 0 instead of an integer divide-by-zero trap
    // (dp) or UB from casting inf/NaN to int (ga).
    const int tp_pp = tp * pp;
    const int dp = (tp_pp != 0) ? world_size / tp_pp : 0;
    const int dp_mbs = dp * mbs;
    // Gradient-accumulation count, truncated toward zero as before.
    const int ga = (dp_mbs != 0)
                       ? static_cast<int>(static_cast<double>(gbs) / dp_mbs)
                       : 0;

    std::ostringstream result;
    result << model_name << '-'
           << "tp" << tp << '-'
           << "pp" << pp << '-'
           << "dp" << dp << '-'
           << "ga" << ga << '-'
           << "ep" << ep << '-'
           << "NVL" << this->net_work_param.gpus_per_server << '-'
           << "DP" << this->net_work_param.dp_overlap_ratio << '-';

    // A trailing '/' means `res` names a directory: append the generated name.
    // Otherwise ("None") the generated name replaces `res`.
    if (res_is_dir) {
      this->res += result.str();
    } else {
      this->res = result.str();
    }
  }
  return 0;
}