in orttraining/orttraining/models/bert/main.cc [86:532]
Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParameters& ort_params) {
cxxopts::Options options("BERT Training", "Main Program to train BERT");
// clang-format off
options
.add_options()
("model_name", "model to be trained", cxxopts::value<std::string>())
("train_data_dir", "Input ONNX example files (can be a glob or comma separated).",
cxxopts::value<std::string>()->default_value("bert_data/128/books_wiki_en_corpus/train"))
("test_data_dir", "Input ONNX example files (can be a glob or comma separated).",
cxxopts::value<std::string>()->default_value("bert_data/128/books_wiki_en_corpus/test"))
("train_data_dir_phase2", "Input ONNX example files (can be a glob or comma separated).",
cxxopts::value<std::string>()->default_value(""))
("test_data_dir_phase2", "Input ONNX example files (can be a glob or comma separated).",
cxxopts::value<std::string>()->default_value(""))
("output_dir", "The output directory where the trained model files will be written.",
cxxopts::value<std::string>()->default_value(""))
("perf_output_dir", "The output directory where the trained perf metrics files will be written.",
cxxopts::value<std::string>()->default_value(""))
("checkpoints_dir", "The output directory where the checkpoint files will be written.",
cxxopts::value<std::string>()->default_value(""))
("checkpoint_to_load_path",
"The path to the checkpoint to load. If not provided, the latest "
"checkpoint in checkpoints_dir, if any, is used.",
cxxopts::value<std::string>()->default_value(""))
("log_dir", "The directory to write tensorboard events.",
cxxopts::value<std::string>()->default_value(""))
("convergence_test_output_file", "The convergence test output file path.",
cxxopts::value<std::string>()->default_value(""))
("train_batch_size", "Total batch size for training.", cxxopts::value<int>())
("train_batch_size_phase2", "Total batch size for training.", cxxopts::value<int>()->default_value("1"))
("eval_batch_size", "Total batch size for eval.", cxxopts::value<int>())
("learning_rate", "The initial learning rate for the optimizer.", cxxopts::value<float>()->default_value("5e-5"))
("learning_rate_phase2", "The initial learning rate for the optimizer.", cxxopts::value<float>()->default_value("4e-3"))
("num_train_steps", "Total number of training steps to perform.", cxxopts::value<int>()->default_value("100000"))
("num_train_steps_phase2", "Total number of training steps to perform.", cxxopts::value<int>()->default_value("1563"))
("warmup_ratio", "Fraction of training steps for learning rate warmup.", cxxopts::value<float>()->default_value("0"))
("warmup_ratio_phase2", "Fraction of training steps for learning rate warmup.", cxxopts::value<float>()->default_value("0.128"))
("warmup_mode", "Warmup mode, one of [None|Cosine|Constant|Linear|Poly], defaults None.",
cxxopts::value<std::string>()->default_value("None"))
("do_eval", "Whether to run eval on the dev set.", cxxopts::value<bool>()->default_value("false"))
("evaluation_period",
"How many training steps to make before making an evaluation.",
cxxopts::value<size_t>()->default_value("100"))
("display_loss_steps", "How often to dump loss into tensorboard", cxxopts::value<size_t>()->default_value("10"))
("gradient_accumulation_steps", "The number of gradient accumulation steps before performing a backward/update pass.",
cxxopts::value<int>()->default_value("1"))
("checkpoint_period", "How many weight-update steps to run before saving a model checkpoint.", cxxopts::value<size_t>()->default_value("1000"))
("max_num_checkpoints", "Maximum number of checkpoint files to maintain.",
cxxopts::value<size_t>()->default_value("10"))
("gradient_accumulation_steps_phase2", "The number of gradient accumulation steps before performing a backward/update pass in phase 2.",
cxxopts::value<int>()->default_value("1"))
("iterations_per_loop", "How many steps to make in each estimator call.", cxxopts::value<int>()->default_value("1000"))
("max_eval_steps", "Maximum number of eval steps.", cxxopts::value<int>()->default_value("100"))
("seed", "Random seed.", cxxopts::value<int64_t>()->default_value("-1"))
("use_deterministic_compute", "Whether to enable deterministic compute.", cxxopts::value<bool>()->default_value("false"))
("use_mixed_precision", "Whether to use a mix of fp32 and fp16 arithmetic on GPU.", cxxopts::value<bool>()->default_value("false"))
("use_bfloat16", "Whether to use BFloat16 arithmetic on GPU.", cxxopts::value<bool>()->default_value("false"))
("enable_adasum", "Whether to use Adasum for allreduction.", cxxopts::value<bool>()->default_value("false"))
("allreduce_in_fp16", "Whether to do AllReduce in fp16. If false, AllReduce will be done in fp32", cxxopts::value<bool>()->default_value("true"))
("loss_scale", "Loss scaling, positive power of 2 values can improve fp16 convergence. "
"Set it 0 to uses dynamic scaling; Other none-zero value will used as static scale",
cxxopts::value<float>()->default_value("0.0"))
("use_fp16_moments", "Whether to use fp16 version of moments.", cxxopts::value<bool>()->default_value("false"))
("use_fp16_initializer", "FP16 weights will be created. Otherwise, cast nodes will be inserted for converting weights from FP32 to FP16",
cxxopts::value<bool>()->default_value("true"))
("use_nccl", "Whether to use NCCL for distributed training.", cxxopts::value<bool>()->default_value("false"))
("use_profiler", "Collect runtime profile data during this training run.", cxxopts::value<bool>()->default_value("false"))
("use_gist", "Whether to use GIST encoding/decoding.")
("gist_op", "Opearator type(s) to which GIST is applied.", cxxopts::value<int>()->default_value("0"))
("gist_compr", "Compression type used for GIST", cxxopts::value<std::string>()->default_value("GistPack8"))
("max_profile_records", "Maximum number of runtime profile data records to collect. 0 means use the default value.",
cxxopts::value<size_t>()->default_value("0"))
("mode", "mode for running, can be one of [train|perf]", cxxopts::value<std::string>()->default_value("train"))
("histogram", "Tensor(s) to display a histogram on tensorboard (e.g. '417,3347,417_grad,3347_grad' for bert-large or '81,449,81_grad,449_grad' for bert-tiny)",
cxxopts::value<std::vector<std::string>>()->default_value({}))
("norm", "Tensor(s) to display their L2-norm values on tensorboard (e.g. '417,3347,417_grad,3347_grad' for bert-large or '81,449,81_grad,449_grad' for bert-tiny)",
cxxopts::value<std::vector<std::string>>()->default_value({}))
("dump_convergence_metrics", "specify if tensorboard should include convergence metrics such as gradient norm.",
cxxopts::value<bool>()->default_value("false"))
("max_seq_length",
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded. Must match data generation.", cxxopts::value<int>()->default_value("512"))
("max_predictions_per_seq",
"Maximum number of masked LM predictions per sequence. "
"Must match data generation.", cxxopts::value<int>()->default_value("80"))
("optimizer", "Adam or Lamb", cxxopts::value<std::string>()->default_value("Adam"))
("deepspeed_zero_stage", "Controls whether to partition state using the DeepSpeed ZeRO technique. "
"Stages 0 (disabled) and 1 (optimizer state partitioning) are supported.",
cxxopts::value<int>()->default_value("0"))
("alpha", "Adam/Lamb alpha parameter", cxxopts::value<float>()->default_value("0.9"))
("beta", "Adam/Lamb beta parameter", cxxopts::value<float>()->default_value("0.999"))
("lambda", "Adam/Lamb lambda parameter", cxxopts::value<float>()->default_value("0.01"))
("epsilon", "Adam/Lamb epsilon parameter", cxxopts::value<float>()->default_value("1e-6"))
("do_bias_correction",
"A flag controls if Adam/Lamb should do bias correction. "
"Default is false, which means no bias correction. "
"Use true to enable bias correction.",
cxxopts::value<bool>()->default_value("false"))
("weight_decay_mode",
"Chooses the weight decay mode for Adam optimizer "
"Default is 0, which does weight decay before updating weight. "
"Use 1 to do weight decay after updating weight.",
cxxopts::value<int64_t>()->default_value("0"))
("ratio_min", "Lamb min ratio parameter", cxxopts::value<float>()->default_value("0.05"))
("ratio_max", "Lamb max ratio parameter", cxxopts::value<float>()->default_value("5.0"))
("gpu_mem_limit_in_gb", "Max cuda memory ort can use, in GB", cxxopts::value<float>()->default_value("-1.0"))
("data_parallel_size", "Data parallel group size.", cxxopts::value<int>()->default_value("1"))
("horizontal_parallel_size", "Horizontal model parallel group size.", cxxopts::value<int>()->default_value("1"))
("pipeline_parallel_size", "Number of pipeline stages.", cxxopts::value<int>()->default_value("1"))
("pipeline_stage_paths", "Specify the forward ONNX files for pipeline evaluation.", cxxopts::value<std::vector<std::string>>()->default_value(""))
("cut_group_info", "Specify the cutting info for graph partition (pipeline only). An example of a cut_group_info of "
"size two is: 1393:407-1463/1585/1707,2369:407-2439/2561/2683. Here, the cut info is split by ',', with the first "
"cut_info equal to 1393:407-1463/1585/1707, and second cut_info equal to 2369:407-2439/2561/2683. Each CutEdge is "
"seperated by ':'. If consumer nodes need to be specified, specify them after producer node with a '-' delimiter and "
"separate each consumer node with a '/'. ", cxxopts::value<std::vector<std::string>>()->default_value(""))
("enable_grad_norm_clip", "Specify whether to enable gradient clipping for optimizers.",
cxxopts::value<bool>()->default_value("true"))
("enable_gelu_approximation", "Specify whether to enable GELU approximation.",
cxxopts::value<bool>()->default_value("true"))
("attn_dropout_recompute", "Enable checkpointing of attention dropout to save memory.",
cxxopts::value<bool>()->default_value("false"))
("gelu_recompute", "Enable checkpointing of Gelu activation output to save memory.",
cxxopts::value<bool>()->default_value("false"))
("transformer_layer_recompute", "Enable checkpointing of transformer layer output to save memory.",
cxxopts::value<bool>()->default_value("false"))
("number_recompute_layers", "Number of layers to apply recompute.",
cxxopts::value<int>()->default_value("0"))
("use_memory_efficient_gradient", "Specify whether to use memory aware gradient builder.)",
cxxopts::value<bool>()->default_value("false"))
("debug_break", "Specify whether to break at app start, useful for multi-gpu debugging.",
cxxopts::value<bool>()->default_value("false"));
options
.add_options("ORT configuration")
("ort_log_severity", "ORT minimum logging severity (see onnxruntime::logging::Severity values)",
cxxopts::value<int>()->default_value("2"/*logging::Severity::kWARNING*/))
("ort_vlog_level", "ORT maximum VLOG level (verbose debug logging)",
cxxopts::value<int>()->default_value("-1"));
// clang-format on
try {
auto flags = options.parse(argc, argv);
params.model_name = flags["model_name"].as<std::string>();
params.debug_break = flags["debug_break"].as<bool>();
float lr = flags["learning_rate"].as<float>();
if (lr > 1.f || lr < 0.f) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "learning_rate is not in valid range [0.0, 1.0]");
}
params.lr_params.initial_lr = lr;
float lr_phase2 = flags["learning_rate_phase2"].as<float>();
if (lr_phase2 > 1.f || lr_phase2 < 0.f) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "learning_rate_phase2 is not in valid range [0.0, 1.0]");
}
params.initial_lr_phase2 = lr_phase2;
float ratio = flags["warmup_ratio"].as<float>();
if (ratio > 1.f || ratio < 0.f) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "warmup_ratio is not in valid range [0.0, 1.0]");
}
params.lr_params.warmup_ratio = ratio;
params.gpu_mem_limit_in_gb = flags["gpu_mem_limit_in_gb"].as<float>();
float ratio_phase2 = flags["warmup_ratio_phase2"].as<float>();
if (ratio_phase2 > 1.f || ratio_phase2 < 0.f) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "warmup_ratio_phase2 is not in valid range [0.0, 1.0]");
}
params.warmup_ratio_phase2 = ratio_phase2;
params.num_train_steps = flags["num_train_steps"].as<int>();
params.num_train_steps_phase2 = flags["num_train_steps_phase2"].as<int>();
params.batch_size = flags["train_batch_size"].as<int>();
params.gist_config.op_type = flags["gist_op"].as<int>();
params.gist_config.compr_type = flags["gist_compr"].as<std::string>();
if (flags.count("eval_batch_size")) {
params.eval_batch_size = flags["eval_batch_size"].as<int>();
} else {
params.eval_batch_size = params.batch_size;
}
params.batch_size_phase2 = flags["train_batch_size_phase2"].as<int>();
params.max_sequence_length = flags["max_seq_length"].as<int>();
params.max_predictions_per_sequence = flags["max_predictions_per_seq"].as<int>();
params.gradient_accumulation_steps = flags["gradient_accumulation_steps"].as<int>();
if (params.gradient_accumulation_steps < 1) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid gradient_accumulation_steps parameter: should be >= 1");
}
params.gradient_accumulation_steps_phase2 = flags["gradient_accumulation_steps_phase2"].as<int>();
if (params.gradient_accumulation_steps_phase2 < 1) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid gradient_accumulation_steps_phase2 parameter: should be >= 1");
}
params.do_eval = flags["do_eval"].as<bool>();
params.evaluation_period = flags["evaluation_period"].as<size_t>();
params.display_loss_steps = flags["display_loss_steps"].as<size_t>();
params.checkpoint_period = flags["checkpoint_period"].as<size_t>();
params.max_num_checkpoints = flags["max_num_checkpoints"].as<size_t>();
params.use_nccl = flags["use_nccl"].as<bool>();
params.enable_adasum = flags["enable_adasum"].as<bool>();
params.use_profiler = flags.count("use_profiler") > 0;
ort_params.max_num_profiling_events = flags["max_profile_records"].as<size_t>();
params.train_data_dir = ToPathString(flags["train_data_dir"].as<std::string>());
params.test_data_dir = ToPathString(flags["test_data_dir"].as<std::string>());
params.log_dir = ToPathString(flags["log_dir"].as<std::string>());
params.train_data_dir_phase2 = ToPathString(flags["train_data_dir_phase2"].as<std::string>());
params.test_data_dir_phase2 = ToPathString(flags["test_data_dir_phase2"].as<std::string>());
params.convergence_test_output_file = ToPathString(flags["convergence_test_output_file"].as<std::string>());
params.output_dir = ToPathString(flags["output_dir"].as<std::string>());
if (params.output_dir.empty()) {
printf("No output directory specified. Trained model files will not be saved.\n");
}
params.perf_output_dir = ToPathString(flags["perf_output_dir"].as<std::string>());
if (params.perf_output_dir.empty()) {
printf("No perf output directory specified. Trained perf metrics will not be saved.\n");
}
params.checkpoints_dir = ToPathString(flags["checkpoints_dir"].as<std::string>());
if (params.checkpoints_dir.empty()) {
printf("No checkpoints directory specified. Checkpoint files will not be saved.\n");
}
params.checkpoint_to_load_path = ToPathString(flags["checkpoint_to_load_path"].as<std::string>());
params.histogram_names = flags["histogram"].as<std::vector<std::string>>();
params.norm_names = flags["norm"].as<std::vector<std::string>>();
params.dump_convergence_metrics = flags["dump_convergence_metrics"].as<bool>();
std::string mode = flags["mode"].as<std::string>();
if (mode == "perf" || mode == "train") {
params.is_perf_test = mode == "perf";
} else {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Incorrect command line for mode: it must be one of [perf|train]");
}
params.use_mixed_precision = flags["use_mixed_precision"].as<bool>();
params.use_bfloat16 = flags["use_bfloat16"].as<bool>();
params.allreduce_in_mixed_precision_type = flags["allreduce_in_fp16"].as<bool>() && params.use_mixed_precision;
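// Note: allreduce runs in the mixed precision type only when mixed precision
// training is itself enabled; --allreduce_in_fp16 alone has no effect.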
if (params.use_mixed_precision) {
printf("Mixed precision training is enabled.\n");
}
if (params.allreduce_in_mixed_precision_type) {
printf("Performing AllReduce in mixed precision type \n");
} else {
printf("Performing AllReduce in fp32 \n");
}
{
const float loss_scale = flags["loss_scale"].as<float>();
if (loss_scale < 0.0f) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Loss scale should be >= 0.");
}
params.loss_scale = loss_scale;
if (params.use_mixed_precision) {
if (params.loss_scale == 0.0) {
printf("Using Dynamic loss scale.\n");
} else {
printf("Mixed precision loss scale is: %f\n", params.loss_scale);
}
}
}
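// Illustrative usage (values are examples, not recommendations): passing
// --use_mixed_precision --loss_scale 0 selects dynamic loss scaling, while
// --use_mixed_precision --loss_scale 1024 fixes a static scale of 1024.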
params.use_mixed_precision_moments = flags["use_fp16_moments"].as<bool>();
if (params.use_mixed_precision_moments) {
printf("Using mixed precision version of moments.\n");
}
params.use_mixed_precision_initializer = flags["use_fp16_initializer"].as<bool>();
if (params.use_mixed_precision && params.use_mixed_precision_initializer) {
printf("Mixed precision initializer is enabled.\n");
}
std::string warmup_mode = flags["warmup_mode"].as<std::string>();
if (warmup_mode == LRSchedule_NoWarmup ||
warmup_mode == LRSchedule_Cosine ||
warmup_mode == LRSchedule_Constant ||
warmup_mode == LRSchedule_Linear ||
warmup_mode == LRSchedule_Poly) {
params.lr_params.warmup_mode = warmup_mode;
printf("Using learning rate warmup mode: %s \n", warmup_mode.c_str());
} else {
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
"Incorrect warmup_mode: it must be one of [None|Cosine|Constant|Linear|Poly]");
}
std::string optimizer_name = flags["optimizer"].as<std::string>();
if (optimizer_name == "adam" || optimizer_name == "Adam") {
params.training_optimizer_name = "AdamOptimizer";
} else if (optimizer_name == "lamb" || optimizer_name == "Lamb") {
params.training_optimizer_name = "LambOptimizer";
} else {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Incorrect optimizer type: it must be one of [Adam|Lamb]");
}
params.deepspeed_zero = ZeROConfig(flags["deepspeed_zero_stage"].as<int>());
params.enable_grad_norm_clip = flags["enable_grad_norm_clip"].as<bool>();
params.use_gist = flags.count("use_gist") > 0;
float alpha = flags["alpha"].as<float>();
float beta = flags["beta"].as<float>();
float lambda = flags["lambda"].as<float>();
float epsilon = flags["epsilon"].as<float>();
int64_t weight_decay_mode = flags["weight_decay_mode"].as<int64_t>();
float ratio_min = flags["ratio_min"].as<float>();
float ratio_max = flags["ratio_max"].as<float>();
ORT_RETURN_IF_NOT(alpha >= 0.f && alpha <= 1.f, "alpha is not in valid range [0.0, 1.0]");
ORT_RETURN_IF_NOT(beta >= 0.f && beta <= 1.f, "beta is not in valid range [0.0, 1.0]");
ORT_RETURN_IF_NOT(weight_decay_mode == 0 || weight_decay_mode == 1, "Only 0 and 1 are supported for weight decay mode.");
ORT_RETURN_IF_NOT(epsilon >= 0.f, "epsilon should be non-negative.");
ORT_RETURN_IF_NOT(ratio_min >= 0.f, "ratio_min should be non-negative.");
ORT_RETURN_IF_NOT(ratio_max >= 0.f, "ratio_max should be non-negative.");
ORT_RETURN_IF_NOT(ratio_max >= ratio_min, "ratio_max should be greater than or equal to ratio_min.");
std::vector<std::string> no_decay{"bias", "gamma", "beta", "LayerNorm"};
bool do_bias_correction = flags["do_bias_correction"].as<bool>();
// Optimizer's float attributes.
params.optimizer_attributes = [=](const std::string& weight) {
// Set lambda attribute to zero if we don't want decay on this weight.
bool zero_lambda = std::any_of(no_decay.begin(), no_decay.end(), [&](const std::string& name) {
return weight.find(name) != std::string::npos;
});
return std::unordered_map<std::string, float>{
{"alpha", alpha},
{"beta", beta},
{"lambda", zero_lambda ? 0.f : lambda},
{"epsilon", epsilon},
{"ratio_min", ratio_min},
{"ratio_max", ratio_max}};
};
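// Illustrative example (weight names are hypothetical): a weight named
// "encoder.layer.0.output.LayerNorm.gamma" contains both "gamma" and
// "LayerNorm" from the no_decay list, so its "lambda" attribute is forced to 0
// (no weight decay), while a weight named "query_weight" keeps the configured
// lambda value.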
// Optimizer's int attributes.
params.optimizer_int_attributes = [=](const std::string& /*weight*/) {
return std::unordered_map<std::string, int64_t>{
{"do_bias_correction", do_bias_correction ? static_cast<int64_t>(1) : static_cast<int64_t>(0)},
{"weight_decay_mode", weight_decay_mode}};
};
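// e.g. with --do_bias_correction true --weight_decay_mode 1, every weight maps
// to {"do_bias_correction": 1, "weight_decay_mode": 1}.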
params.data_parallel_size = flags["data_parallel_size"].as<int>();
params.horizontal_parallel_size = flags["horizontal_parallel_size"].as<int>();
ORT_RETURN_IF_NOT(params.data_parallel_size > 0, "data_parallel_size must be > 0");
ORT_RETURN_IF_NOT(params.horizontal_parallel_size > 0, "horizontal_parallel_size must be > 0");
// pipeline_parallel_size controls the number of pipeline's stages.
// pipeline_parallel_size=1 means no model partition, which means all processes run
// the same model. We only partition model when pipeline_parallel_size > 1.
params.pipeline_parallel_size = flags["pipeline_parallel_size"].as<int>();
ORT_RETURN_IF_NOT(params.pipeline_parallel_size > 0, "pipeline_parallel_size must be > 0");
// If user provides partitioned model files, the number of files should match the number of
// processes. The i-th file should correspond to the i-th process' pipeline stage.
// All files only store forward pass with a Recv and a Send.
// Backward pass and optimizer nodes are implicitly generated by ORT.
params.pipeline_stage_paths = flags["pipeline_stage_paths"].as<std::vector<std::string>>();
// If user doesn't provide partitioned model files, a cut list should be provided for ORT to do partition
// online. If the pipeline contains n stages, the cut list should be of length (n-1), in order to cut the
// graph into n partitions.
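// Worked example, reusing the sample from the --cut_group_info help text:
// "1393:407-1463/1585/1707" first splits on ':' into the CutEdge strings
// "1393" and "407-1463/1585/1707". The first contains no '-', so it becomes
// CutEdge{"1393"} (producer only). The second splits on '-' into producer
// "407" and consumer string "1463/1585/1707", which splits on '/' into
// consumers {"1463", "1585", "1707"}, yielding
// CutEdge{"407", {"1463", "1585", "1707"}}.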
if (params.pipeline_parallel_size > 1 && params.pipeline_stage_paths.empty()) {
auto cut_info_groups = flags["cut_group_info"].as<std::vector<std::string>>();
ORT_RETURN_IF_NOT(static_cast<int>(cut_info_groups.size() + 1) == params.pipeline_parallel_size,
"cut_info length plus one must match pipeline parallel size");
auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) {
std::vector<std::string> result;
size_t pos = 0;
std::string token;
while ((pos = input_str.find(delimiter)) != std::string::npos) {
token = input_str.substr(0, pos);
result.emplace_back(token);
input_str.erase(0, pos + delimiter.length());
}
// push the last split of substring into result.
result.emplace_back(input_str);
return result;
};
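// e.g. process_with_delimiter("407-1463/1585", "-") returns {"407", "1463/1585"}.
// Note that input_str is modified in place: after the call it holds only the
// final segment.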
auto process_cut_info = [&](std::string& cut_info_string) {
TrainingSession::TrainingConfiguration::CutInfo cut_info;
const std::string edge_delimiter = ":";
const std::string consumer_delimiter = "/";
const std::string producer_consumer_delimiter = "-";
auto cut_edges = process_with_delimiter(cut_info_string, edge_delimiter);
for (auto& cut_edge : cut_edges) {
auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter);
if (process_edge.size() == 1) {
TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]};
cut_info.emplace_back(edge);
} else {
ORT_ENFORCE(process_edge.size() == 2);
auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter);
TrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list};
cut_info.emplace_back(edge);
}
}
return cut_info;
};
for (auto& cut_info : cut_info_groups) {
TrainingSession::TrainingConfiguration::CutInfo cut = process_cut_info(cut_info);
params.pipeline_partition_cut_list.emplace_back(cut);
}
}
int64_t seed = flags["seed"].as<int64_t>();
if (params.horizontal_parallel_size > 1 && seed <= 0) {
seed = 8211; // Megatron needs a random seed.
}
if (seed > 0) {
utils::SetRandomSeed(seed);
std::cout << "Random seed is set to: " << seed << std::endl;
}
session_options.use_deterministic_compute = flags["use_deterministic_compute"].as<bool>();
params.enable_gelu_approximation = flags["enable_gelu_approximation"].as<bool>();
params.attn_dropout_recompute = flags["attn_dropout_recompute"].as<bool>();
params.gelu_recompute = flags["gelu_recompute"].as<bool>();
params.transformer_layer_recompute = flags["transformer_layer_recompute"].as<bool>();
params.number_recompute_layers = flags["number_recompute_layers"].as<int>();
ort_params.log_severity = static_cast<logging::Severity>(flags["ort_log_severity"].as<int>());
ORT_RETURN_IF_NOT(
logging::Severity::kVERBOSE <= ort_params.log_severity &&
ort_params.log_severity <= logging::Severity::kFATAL,
"Log severity must be in the range [", static_cast<int>(logging::Severity::kVERBOSE),
", ", static_cast<int>(logging::Severity::kFATAL), "].");
ort_params.vlog_level = flags["ort_vlog_level"].as<int>();
params.use_memory_efficient_gradient = flags["use_memory_efficient_gradient"].as<bool>();
} catch (const exception& e) {
const std::string msg = "Failed to parse the command line arguments";
cerr << msg << ": " << e.what() << "\n"
<< options.help() << "\n";
return Status(ONNXRUNTIME, INVALID_ARGUMENT, msg);
}
return Status::OK();
}