// astra-sim-alibabacloud/astra-sim/system/AstraParamParse.hh (257 lines of code) (raw):
/*
*Copyright (c) 2024, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
*/
#ifndef __ASTRAPARAMPARSE_HH__
#define __ASTRAPARAMPARSE_HH__
#include <stdio.h>
#include <unistd.h>

#include <chrono>
#include <cstdarg>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <regex>
#include <sstream>
#include <string>
#include <vector>

#include "Common.hh"
#define BUSBW_PATH ""
using namespace std;
// Run mode stored in UserParam::mode (UserParam's constructor selects MOCKNCCL).
// Enumerator order, and therefore the underlying values 0..3, must stay as is.
enum class ModeType {
  NONE,
  ASTRA_SIM,
  MOCKNCCL,
  ANALYTICAL
};
// Network/topology parameters carried by UserParam::net_work_param.
//
// Every field has a deterministic default so that a default-constructed
// instance can be read safely before any option has been parsed (the integer
// fields used to be left uninitialized).
struct NetWorkParam {
  // Topology counts. node_num, switch_num and nvswitch_num are derived in
  // UserParam::parseArg() once the GPU count is known; link_num and trace_num
  // are not written anywhere in this header.
  uint32_t node_num = 0;
  uint32_t switch_num = 0;
  uint32_t link_num = 0;
  uint32_t trace_num = 0;
  uint32_t nvswitch_num = 0;

  // Per-server layout. gpus_per_server comes from "-g_p_s/--gpus-per-server";
  // 0 means "not provided". The remaining fields are not set in this header.
  // NOTE(review): units of nvlink_bw / nic_bw are not visible here - confirm.
  uint32_t gpus_per_server = 0;
  uint32_t nics_per_server = 0;
  uint32_t nvlink_bw = 0;
  uint32_t nic_bw = 0;

  // Value-initialized; NOTE(review): confirm that the zero value of GPUType
  // (declared in Common.hh) is an acceptable "unset" default.
  GPUType gpu_type{};

  // Bus-bandwidth values loaded by UserParam::parseYaml(). The prefix is the
  // YAML section (TP / DP / EP) and the suffix the key: ar = "allreduce",
  // ag = "allgather", rs = "reducescatter", ata = "alltoall".
  // -1.0f means the value was never provided.
  float tp_ar = -1.0f;
  float tp_ag = -1.0f;
  float tp_rs = -1.0f;
  float tp_ata = -1.0f;
  float dp_ar = -1.0f;
  float dp_ag = -1.0f;
  float dp_rs = -1.0f;
  float dp_ata = -1.0f;
  float ep_ar = -1.0f;
  float ep_ag = -1.0f;
  float ep_rs = -1.0f;
  float ep_ata = -1.0f;
  // Key "busbw" of the YAML section "PP".
  float pp = -1.0f;

  // Overlap ratios set by the "--{dp,tp,ep,pp}-overlap-ratio" options.
  float dp_overlap_ratio = 0;
  float tp_overlap_ratio = 0;
  float ep_overlap_ratio = 0;
  float pp_overlap_ratio = 1;

  // Not populated in this header.
  // NOTE(review): presumably NVSwitch node ids and per-group GPU ids - confirm.
  std::vector<int> NVswitchs;
  std::vector<std::vector<int>> all_gpus;

  // Set to 1 by "-v/--visual".
  int visual = 0;
};
class UserParam {
private:
static UserParam* instance;
static std::mutex mtx;
UserParam() {
thread = 1;
gpus = {};
workload = {};
comm_scale = 1;
mode = ModeType::MOCKNCCL;
}
public:
int thread;
std::vector<int> gpus;
string workload;
string res = "None";
int comm_scale;
ModeType mode;
NetWorkParam net_work_param;
static UserParam* getInstance(){
std::lock_guard<std::mutex> lock(mtx);
if(instance == nullptr){
instance = new UserParam();
}
return instance;
}
void parseYaml(NetWorkParam& params, const std::string& filename) {
std::ifstream file(BUSBW_PATH + filename);
if (!file) {
std::cerr << "Unable to open file: " << filename << std::endl;
exit(-1);
}
std::string line;
std::string currentSection;
std::getline(file, line);
while (std::getline(file, line)) {
// Remove whitespace
line.erase(0, line.find_first_not_of(' '));
line.erase(line.find_last_not_of(' ') + 1);
if (line.empty() || line[0] == '#') continue;
if (line.back() == ':') {
currentSection = line.substr(0, line.size() - 1);
} else {
std::istringstream ss(line);
std::string key, valueStr;
if (std::getline(ss, key, ':') && ss >> valueStr) {
key.erase(key.find_last_not_of(' ') + 1);
// Remove part after comma
auto commaPos = key.find(',');
if (commaPos != std::string::npos) {
key = key.substr(0, commaPos);
}
if (valueStr != "null") {
float value = std::stof(valueStr);
if (currentSection == "TP") {
if (key == "allreduce") params.tp_ar = value;
else if (key == "allgather") params.tp_ag = value;
else if (key == "reducescatter") params.tp_rs = value;
else if (key == "alltoall") params.tp_ata = value;
} else if (currentSection == "DP") {
if (key == "allreduce") params.dp_ar = value;
else if (key == "allgather") params.dp_ag = value;
else if (key == "reducescatter") params.dp_rs = value;
else if (key == "alltoall") params.dp_ata = value;
} else if (currentSection == "EP") {
if (key == "allreduce") params.ep_ar = value;
else if (key == "allgather") params.ep_ag = value;
else if (key == "reducescatter") params.ep_rs = value;
else if (key == "alltoall") params.ep_ata = value;
} else if (currentSection == "PP") {
if (key == "busbw") params.pp = value;
}
}
}
}
}
}
void printHelp() const {
std::cout << " ____ _ _ ___ _ _ _ _ _ \n"
<< "/ ___|(_)_ __ ___ / \\ |_ _| / \\ _ __ __ _| |_ _| |_(_) ___ __ _| |\n"
<< "\\___ \\| | '_ ' _ \\ / _ \\ | |_____ / _ \\ | '_ \\ / _' | | | | | __| |/ __/ _' | |\n"
<< " ___) | | | | | | |/ ___ \\ | |_____/ ___ \\| | | | (_| | | |_| | |_| | (_| (_| | |\n"
<< "|____/|_|_| |_| |_/_/ \\_\\___| /_/ \\_\\_| |_|\\__,_|_|\\__, |\\__|_|\\___\\__,_|_|\n"
<< " |___/ \n";
std::cout << "-w, --workload Workloads, must set" << std::endl;
std::cout << "-g, --gpus Number of GPUs, default 1" << std::endl;
std::cout << "-g_p_s, --gpus-per-server GPUs per server" << std::endl;
std::cout << "-r, --result Output results path, default: ./results/" << std::endl;
std::cout << "-busbw, --bus-bandwidth Bus bandwidth file, must set" << std::endl;
std::cout << "-v, --visual Enable visual output (Default disable)" << std::endl;
std::cout << "-dp_o, --dp-overlap-ratio DP overlap ratio [float: 0.0-1.0] (Default: 0.0)" << std::endl;
std::cout << "-ep_o, --ep-overlap-ratio EP overlap ratio [float: 0.0-1.0] (Default: 0.0)" << std::endl;
std::cout << "-tp_o, --tp-overlap-ratio TP overlap ratio [float: 0.0-1.0] (Default: 0.0)" << std::endl;
std::cout << "-pp_o, --pp-overlap-ratio PP overlap ratio [float: 0.0-1.0] (Default: 1.0)" << std::endl;
}
int printError(const std::string& arg) const {
std::cerr << "Error: Missing value for argument '" << arg << "'." << std::endl;
return 1;
}
int printUnknownOption(const std::string& arg) const {
std::cerr << "Error: Unknown option '" << arg << "'." << std::endl;
return 1;
}
int parseArg(int argc, char *argv[]) {
for (int i = 1; i < argc; ++i) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
printHelp();
return 1;
} else if (arg == "-w" || arg == "--workload") {
if (++i < argc) this->workload = argv[i];
else return printError(arg);
} else if (arg == "-g" || arg == "--gpus") {
if (++i < argc) this->gpus.push_back(std::stoi(argv[i]));
else return printError(arg);
} else if (arg == "-r" || arg == "--result") {
if (++i < argc) this->res = argv[i];
else return printError(arg);
} else if (arg == "-g_p_s" || arg == "--gpus-per-server") {
if (++i < argc) this->net_work_param.gpus_per_server = std::stoi(argv[i]);
else return printError(arg);
} else if (arg == "-busbw" || arg == "--bus-bandwidth") {
if (++i < argc) parseYaml(this->net_work_param,argv[i]);
else return printError(arg);
} else if (arg == "--dp-overlap-ratio" || arg == "-dp_o") {
if (++i < argc) this->net_work_param.dp_overlap_ratio = std::stof(argv[i]);
else return printError(arg);
} else if (arg == "--tp-overlap-ratio" || arg == "-tp_o") {
if (++i < argc) this->net_work_param.tp_overlap_ratio = std::stof(argv[i]);
else return printError(arg);
} else if (arg == "--ep-overlap-ratio" || arg == "-ep_o") {
if (++i < argc) this->net_work_param.ep_overlap_ratio = std::stof(argv[i]);
else return printError(arg);
} else if (arg == "--pp-overlap-ratio" || arg == "-pp_o") {
if (++i < argc) this->net_work_param.pp_overlap_ratio = std::stof(argv[i]);
else return printError(arg);
} else if (arg == "-v" || arg == "--visual") {
this->net_work_param.visual = 1;
}
else {
return printUnknownOption(arg);
}
}
if (!this->gpus.empty()) {
this->net_work_param.nvswitch_num = this->gpus[0] / this->net_work_param.gpus_per_server;
this->net_work_param.switch_num = 120 + this->net_work_param.gpus_per_server;
this->net_work_param.node_num = this->net_work_param.nvswitch_num + this->net_work_param.switch_num + this->gpus[0];
}
if (this->res == "None" || this->res.back() == '/'){
std::string full_path = this->workload;
std::string model_info = full_path;
size_t last_slash_pos = full_path.find_last_of('/');
if (last_slash_pos != std::string::npos) {
model_info = full_path.substr(last_slash_pos + 1);
}
std::string model_name;
int world_size = 0, tp = 0, pp = 0, ep = 0, gbs = 0, mbs = 0, seq = 0;
size_t world_size_pos = model_info.find("world_size");
if (world_size_pos != std::string::npos) {
model_name = model_info.substr(0, world_size_pos - 1);
}
std::regex param_regex(R"((world_size|tp|pp|ep|gbs|mbs|seq)(\d+))");
std::smatch matches;
std::string params = model_info;
while (std::regex_search(params, matches, param_regex)) {
std::string param_name = matches[1].str();
int param_value = std::stoi(matches[2].str());
if (param_name == "world_size") {
world_size = param_value;
} else if (param_name == "tp") {
tp = param_value;
} else if (param_name == "pp") {
pp = param_value;
} else if (param_name == "ep") {
ep = param_value;
} else if (param_name == "gbs") {
gbs = param_value;
} else if (param_name == "mbs") {
mbs = param_value;
} else if (param_name == "seq") {
seq = param_value;
}
params = matches.suffix().str();
}
int dp = world_size / (tp * pp);
double ga = static_cast<double>(gbs) / (dp * mbs);
std::ostringstream result;
result << model_name << '-'
<< "tp" << tp << '-'
<< "pp" << pp << '-'
<< "dp" << dp << '-'
<< "ga" << static_cast<int>(ga) << '-'
<< "ep" << ep << '-'
<< "NVL" << this->net_work_param.gpus_per_server << '-'
<< "DP" << this->net_work_param.dp_overlap_ratio << '-' ;
if(this->res.back() == '/') {
this->res = this->res + result.str();
}
else{
this->res = result.str();
}
}
return 0;
}
~UserParam(){}
};
#endif // __ASTRAPARAMPARSE_HH__