int parseArg()

in astra-sim-alibabacloud/astra-sim/system/AstraParamParse.hh [188:296]


    /**
     * Parse the command line into this object and derive dependent settings.
     *
     * Phase 1 walks argv: every value-taking option consumes the following
     * argv entry. A missing value is reported through printError() and an
     * unrecognised option through printUnknownOption(); in both cases their
     * return value is returned to the caller immediately.
     *
     * Phase 2 fills nvswitch_num / switch_num / node_num from the FIRST -g
     * value and gpus_per_server.
     *
     * Phase 3 builds a result-name prefix when `res` is "None" (replace) or
     * ends in '/' (append). The prefix is assembled from tokens such as
     * "tp8" / "pp2" / "world_size64" found in the workload file name.
     *
     * @param argc number of entries in argv.
     * @param argv argument vector; argv[0] is skipped.
     * @return 1 after printing help, the value returned by printError() /
     *         printUnknownOption() on a malformed command line, 0 otherwise.
     * @note Numeric values are converted with std::stoi / std::stof without a
     *       try/catch, so malformed or out-of-range numbers (on the command
     *       line or in the workload file name) propagate as exceptions.
     */
    int parseArg(int argc, char *argv[]) {
        for (int i = 1; i < argc; ++i) {
            std::string arg = argv[i];
            if (arg == "-h" || arg == "--help") {
                printHelp();
                return 1;
            } else if (arg == "-w" || arg == "--workload") {
                if (++i < argc) this->workload = argv[i];
                else return printError(arg);
            } else if (arg == "-g" || arg == "--gpus") {
                if (++i < argc) this->gpus.push_back(std::stoi(argv[i]));
                else return printError(arg);
            } else if (arg == "-r" || arg == "--result") {
                if (++i < argc) this->res = argv[i];
                else return printError(arg);
            } else if (arg == "-g_p_s" || arg == "--gpus-per-server") {
                if (++i < argc) this->net_work_param.gpus_per_server = std::stoi(argv[i]);
                else return printError(arg);
            } else if (arg == "-busbw" || arg == "--bus-bandwidth") {
                // The value is handed to parseYaml together with net_work_param.
                if (++i < argc) parseYaml(this->net_work_param,argv[i]);
                else return printError(arg);
            } else if (arg == "--dp-overlap-ratio" || arg == "-dp_o") {
                if (++i < argc) this->net_work_param.dp_overlap_ratio = std::stof(argv[i]);
                else return printError(arg);
            } else if (arg == "--tp-overlap-ratio" || arg == "-tp_o") {
                if (++i < argc) this->net_work_param.tp_overlap_ratio = std::stof(argv[i]);
                else return printError(arg);
            } else if (arg == "--ep-overlap-ratio" || arg == "-ep_o") {
                if (++i < argc) this->net_work_param.ep_overlap_ratio = std::stof(argv[i]);
                else return printError(arg);
            } else if (arg == "--pp-overlap-ratio" || arg == "-pp_o") {
                if (++i < argc) this->net_work_param.pp_overlap_ratio = std::stof(argv[i]);
                else return printError(arg);
            } else if (arg == "-v" || arg == "--visual") {
                // Flag only: takes no value.
                this->net_work_param.visual = 1;
            } else {
                return printUnknownOption(arg);
            }
        }

        // Only the first -g value feeds the topology sizes. A zero
        // gpus_per_server would make the division below undefined, so the
        // derived fields are left untouched in that case.
        if (!this->gpus.empty() && this->net_work_param.gpus_per_server != 0) {
            this->net_work_param.nvswitch_num = this->gpus[0] / this->net_work_param.gpus_per_server;
            this->net_work_param.switch_num = 120 + this->net_work_param.gpus_per_server;
            this->net_work_param.node_num = this->net_work_param.nvswitch_num + this->net_work_param.switch_num + this->gpus[0];
        }

        // back() on an empty string is undefined, so test for emptiness first.
        // An empty `res` is neither "None" nor a directory and is left as is.
        const bool res_is_directory = !this->res.empty() && this->res.back() == '/';
        if (this->res == "None" || res_is_directory) {
            // Reduce the workload path to its final component.
            std::string model_info = this->workload;
            const size_t last_slash_pos = model_info.find_last_of('/');
            if (last_slash_pos != std::string::npos) {
                model_info = model_info.substr(last_slash_pos + 1);
            }

            // The model name is whatever precedes "world_size", minus the one
            // separator character in between. When the file name starts with
            // "world_size" there is no name (and `pos - 1` would wrap around).
            std::string model_name;
            const size_t world_size_pos = model_info.find("world_size");
            if (world_size_pos != std::string::npos && world_size_pos > 0) {
                model_name = model_info.substr(0, world_size_pos - 1);
            }

            int world_size = 0;
            int tp = 0;
            int pp = 0;
            int ep = 0;
            int gbs = 0;
            int mbs = 0;

            // Collect every "<key><digits>" token; a later occurrence of a key
            // overrides an earlier one. "seq" is matched (and its digits are
            // converted) but its value is not used for the result name.
            const std::regex param_regex(R"((world_size|tp|pp|ep|gbs|mbs|seq)(\d+))");
            const std::sregex_iterator no_more_matches;
            for (std::sregex_iterator it(model_info.begin(), model_info.end(), param_regex);
                 it != no_more_matches; ++it) {
                const std::string param_name = (*it)[1].str();
                const int param_value = std::stoi((*it)[2].str());

                if (param_name == "world_size") {
                    world_size = param_value;
                } else if (param_name == "tp") {
                    tp = param_value;
                } else if (param_name == "pp") {
                    pp = param_value;
                } else if (param_name == "ep") {
                    ep = param_value;
                } else if (param_name == "gbs") {
                    gbs = param_value;
                } else if (param_name == "mbs") {
                    mbs = param_value;
                }
            }

            // Tokens missing from the file name stay 0; report 0 for the
            // derived quantities instead of dividing by zero.
            const int model_parallel = tp * pp;
            const int dp = (model_parallel != 0) ? world_size / model_parallel : 0;
            const int ga_divisor = dp * mbs;
            const double ga = (ga_divisor != 0) ? static_cast<double>(gbs) / ga_divisor : 0.0;

            std::ostringstream result;
            result << model_name << '-'
                << "tp" << tp << '-'
                << "pp" << pp << '-'
                << "dp" << dp << '-'
                << "ga" << static_cast<int>(ga) << '-'
                << "ep" << ep << '-'
                << "NVL" << this->net_work_param.gpus_per_server << '-'
                << "DP" << this->net_work_param.dp_overlap_ratio << '-' ;

            if (res_is_directory) {
                // Keep the user-supplied directory and append the generated prefix.
                this->res += result.str();
            } else {
                this->res = result.str();
            }
        }
        return 0;
    }