in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/arguments.py [0:0]
import argparse
import os


def parse_args():  # pylint: disable=too-many-statements
"""Parse args."""
parser = argparse.ArgumentParser()
# hyperparameters sent by the client are passed as command-line arguments to the script.
### OPTIMIZATION
opt_grp = parser.add_argument_group(
title="optimization", description="arguments for optimization"
)
opt_grp.add_argument(
"--train_batch_size",
type=int,
default=2,
help="batch size per dp rank, for tensor parallelism degree 8 with pipeline parallel degree 1 this means 8*this batch size per node", # pylint: disable=line-too-long
)
opt_grp.add_argument("--max_steps", "--max_training_steps", type=int, default=5000)
opt_grp.add_argument(
"--epochs", type=int, default=3, help="times of iterating over the training dataset"
)
opt_grp.add_argument("--seed", type=int, default=12345)
opt_grp.add_argument("--same_seed", type=int, default=0)
opt_grp.add_argument("--bf16", default=1, type=int, help="automatic mixed precision training")
opt_grp.add_argument("--fp8", default=1, type=int, help="fp8 mixed precision training")
opt_grp.add_argument("--fp8_amax_history_len", default=1024, type=int, help="amax history length")
opt_grp.add_argument("--fp8_amax_compute_algo", default="max", type=str, help="amax computation algorithm: 'max' or 'most_recent'")
opt_grp.add_argument("--grad_clip", default=1.0, type=float, help="gradient clipping")
opt_grp.add_argument("--weight_decay", default=0.2, type=float, help="weight decay")
opt_grp.add_argument(
"--beta1", default=0.9, type=float, help="beta1 parameter for Adam optimizer"
)
opt_grp.add_argument(
"--beta2", default=0.95, type=float, help="beta2 parameter for Adam optimizer"
)
# Learning rate
lr_grp = parser.add_argument_group(
title="lr", description="arguments for learning rate schedule"
)
lr_grp.add_argument("--lr", type=float, default=0.0001, help="Initial learning rate.")
lr_grp.add_argument(
"--lr_decay_style",
type=str,
default="cosine",
choices=["constant", "linear", "cosine", "exponential", "plateau"],
help="Learning rate decay function.",
)
lr_grp.add_argument(
"--lr_decay_iters",
type=int,
default=47683,
help="number of iterations to decay learning rate over," " If None defaults to train iters",
)
lr_grp.add_argument(
"--min_lr",
type=float,
default=1e-05,
help="Minumum value for learning rate. The scheduler" "clip values below this threshold.",
)
lr_grp.add_argument(
"--warmup",
type=float,
default=0.0032,
help="Percentage of total iterations to warmup on "
"(.01 = 1 percent of all training iters).",
)
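    # Example: warmup=0.0032 with max_steps=5000 gives 0.0032 * 5000 = 16 warmup iterations.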
lr_grp.add_argument(
"--plateau",
type=float,
default=0.0,
help="Percentage of total iterations to keep at max if using plateau lr",
)
### MEMORY USAGE RELATED
mem_grp = parser.add_argument_group(title="memory usage", description="arguments for memory")
mem_grp.add_argument(
"--activation_checkpointing",
type=int,
default=1,
help="enable gradient checkpointing to reduce memory consumption",
)
mem_grp.add_argument("--patch_neox_rope", type=int, default=1)
mem_grp.add_argument("--delayed_param", type=int, default=1)
mem_grp.add_argument(
"--enable_memory_profiling", type=int, default=0, help="Enable memory profile"
)
mem_grp.add_argument(
"--clean_cache",
type=int,
default=0,
help="Clean torch reserved memory at he end of every step",
)
### LOGGING
logging_grp = parser.add_argument_group(
title="logging", description="arguments for logging metrics"
)
logging_grp.add_argument(
"--logging_freq", type=int, default=1, help="number of iterations between logging"
)
logging_grp.add_argument(
"--logging_freq_for_avg",
type=int,
default=50,
help="number of iterations between logging the running avg",
)
logging_grp.add_argument(
"--log_reduced_training_loss",
type=int,
default=0,
help="to log training loss after reducing across all data parallel ranks with logging_freq frequency", # pylint: disable=line-too-long
)
logging_grp.add_argument("--tensorboard_dir", type=str, nargs="+", default=None)
### CHECKPOINTS
ckpt_grp = parser.add_argument_group(title="checkpoints", description="checkpointing arguments")
ckpt_grp.add_argument(
"--num_kept_checkpoints",
nargs="+",
type=int,
default=[2],
help="how many checkpoints to keep before deleting",
)
ckpt_grp.add_argument(
"--checkpoint_freq",
nargs="+",
type=int,
default=[1000],
help="number of iterations between checkpointing",
)
ckpt_grp.add_argument(
"--checkpoint_dir",
nargs="+",
type=str,
default=["/opt/ml/checkpoints"],
help="Saves partial checkpoints (model, optimizer) to this dir, and loads latest checkpoint from this if load_partial is specified.", # pylint: disable=line-too-long
)
ckpt_grp.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="Checkpoint folder name to load from",
)
ckpt_grp.add_argument(
"--checkpoint_type", type=str, default="sharded", choices=["local", "sharded", "use_pg_with_util", "async_sharded", "async_local"]
)
ckpt_grp.add_argument(
"--model_dir",
type=str,
default=None,
help="If not passed, saves it to checkpoint_dir/model. Only saved when save_final_model is 1",
)
ckpt_grp.add_argument("--save_final_model", type=int, default=0)
### I/O
input_grp = parser.add_argument_group(title="inputs", description="location for data")
input_grp.add_argument(
"--dataset_type", type=str, default="gpt_jsonl", choices=["gpt_jsonl", "hf"]
)
input_grp.add_argument("--data_num_workers", type=int, default=0)
input_grp.add_argument("--data_type", type=str.lower, default="gpt", choices=["gpt", "bert"])
# dummy dataset
input_grp.add_argument("--use_synthetic_data", type=int, default=0)
# gpt dataset
input_grp.add_argument("--zipped_data", type=int, default=1, help="input data is zipped files")
input_grp.add_argument("--training_dir", type=str, default=os.getenv("SM_CHANNEL_TRAIN"))
input_grp.add_argument("--test_dir", type=str, default=os.getenv("SM_CHANNEL_TEST"))
### MODEL
model_grp = parser.add_argument_group(
title="model", description="arguments to describe model configuration"
)
model_grp.add_argument(
"--hf_pretrained_model_name_or_dir",
type=str,
default=None,
help=(
"For finetuning, pass the pretrained Huggingface model name or path where the model is downloaded. "
"Example: EleutherAI/gpt-neox-20b. or /path/to/downloaded/model. "
"This flag is used for loading both config and weights. "
"When this config is used, flags such as vocab_size, hidden_width etc are ignored in creating the model. "
"For finetuning you need to set this flag even when resuming from a checkpoint. "
),
)
model_grp.add_argument("--max_context_width", type=int, default=2048)
model_grp.add_argument("--vocab_size", type=int, default=50432)
model_grp.add_argument("--hidden_width", type=int, default=768)
model_grp.add_argument("--num_layers", type=int, default=12)
model_grp.add_argument("--num_heads", type=int, default=12)
model_grp.add_argument("--resid_pdrop", type=float, default=0.1)
model_grp.add_argument("--embd_pdrop", type=float, default=0.1)
model_grp.add_argument("--attn_pdrop", type=float, default=0.1)
model_grp.add_argument("--summary_first_pdrop", type=float, default=0.1)
model_grp.add_argument("--initializer_range", type=float, default=0.02)
model_grp.add_argument(
"--model_type", type=str, default="gpt_neox", choices=["gpt_neox", "llama_v2", "gpt2", "mistral", "mixtral", "llama_v3"]
)
model_grp.add_argument("--rotary_pct", type=float, default=0.25)
model_grp.add_argument(
"--rotary_emb_base",
type=int,
default=10000,
help="The base period of the RoPE embeddings.",
)
model_grp.add_argument(
"--rope_scaling_type",
type=str,
choices=["default", "llama3"],
default=None,
help=(
"The sub-variant of RoPE to use. Can be one of ['default', 'llama3'], "
"with 'default' being the original RoPE implementation and 'llama3' "
"being the Llama3.1 RoPE implementation."
),
)
model_grp.add_argument(
"--rope_scaling_factor",
type=float,
default=8.0,
help=(
"Used with all rope types except 'default'. "
"The scaling factor to apply to the RoPE embeddings. "
"In most scaling types, a `factor` of x will enable the model "
"to handle sequences of length x * original maximum pre-trained length."
),
)
model_grp.add_argument(
"--rope_scaling_high_freq_factor",
type=float,
default=4.0,
help=(
"Only used with 'llama3'. "
"Scaling factor applied to high frequency components of the RoPE."
),
)
model_grp.add_argument(
"--rope_scaling_low_freq_factor",
type=float,
default=1.0,
help=(
"Only used with 'llama3'. "
"Scaling factor applied to low frequency components of the RoPE."
),
)
model_grp.add_argument(
"--rope_scaling_original_max_position_embeddings",
type=int,
default=8192,
help=(
"Used with 'dynamic', 'longrope' and 'llama3'. "
"The original max position embeddings used during pretraining."
),
)
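    # Illustration (assumption about downstream use): for Llama 3.1-style scaling, the flags above
    # correspond to a Hugging Face rope_scaling config such as
    # {"rope_type": "llama3", "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0,
    #  "original_max_position_embeddings": 8192}.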
model_grp.add_argument("--use_smp_flash_attn", type=int, default=1)
model_grp.add_argument(
"--llama_intermediate_size",
type=int,
default=11008,
help="intermediate_size for Llama v2, a dimension associated with MLP",
)
model_grp.add_argument(
"--intermediate_size",
type=int,
default=14336,
help="A specified intermediate_size, a dimension associated with MLP",
)
model_grp.add_argument(
"--sliding_window",
type=int,
default=None,
help="Sliding window attention window size",
)
model_grp.add_argument(
"--num_key_value_heads",
type=int,
default=None,
help="The number of heads for key and value in GQA",
)
model_grp.add_argument(
"--num_experts_per_tok",
type=int,
default=2,
help="The number of experts to root per-token",
)
model_grp.add_argument(
"--num_local_experts",
type=int,
default=8,
help="Number of experts per Sparse MLP layer",
)
model_grp.add_argument(
"--moe_load_balancing",
type=str,
default="sinkhorn",
choices=["sinkhorn", "balanced", "aux_loss", "none"],
help="Load balancing type of MoE router",
)
model_grp.add_argument(
"--global_token_shuffle",
type=int,
default=0,
help="Global token shuffle for MoE router",
)
model_grp.add_argument(
"--moe_all_to_all_dispatcher",
type=int,
default=1,
help="Use MoE All to All token dispatcher",
)
model_grp.add_argument(
"--use_smp_implementation",
type=int,
default=0,
help="Whether to use SMP optimized implementation of model. "
"All models may not be supported."
"When using tensor_parallel_degree, this is automatically enabled.",
)
model_grp.add_argument("--cp_comm_type", type=str, default="p2p", help="Which context parallelism implementation to use, p2p or all_gather. p2p implementation runs asynchronously, allowing compute overlap", choices=["p2p", "all_gather"])
model_grp.add_argument(
"--moe",
type=int,
default=0,
help="Whether to use MoE implementation of Megatron. "
"All models may not be supported."
)
model_grp.add_argument(
"--moe_fp8_checkpoint_attn",
type=int,
default=1,
help="Checkpoint attention layer for MoE"
)
model_grp.add_argument(
"--moe_fp8_checkpoint_moe",
type=int,
default=1,
help="Checkpoint moe layer for MoE"
)
### FSDP args
fsdp_grp = parser.add_argument_group(
title="fsdp", description="arguments for fully sharded data parallel"
)
fsdp_grp.add_argument("--limit_all_gathers", default=1, type=int)
fsdp_grp.add_argument("--forward_prefetch", default=1, type=int)
fsdp_grp.add_argument(
"--sharding_strategy",
type=str,
default="hybrid_shard",
help="options: no_shard, shard_grad_op, hybrid_shard, _hybrid_shard_zero2, full_shard",
)
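    # These option strings mirror the members of torch.distributed.fsdp.ShardingStrategy
    # (NO_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, _HYBRID_SHARD_ZERO2, FULL_SHARD).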
fsdp_grp.add_argument(
"--use_orig_params",
default=0,
type=int,
help="This flag needs to be set when you need multiple param groups for optimizer, such as for weight decay",
)
    # Note that `shard_degree` might override `sharding_strategy`:
    #
    # 1. When `shard_degree` is not set explicitly, or is `0`, fall back to native PyTorch for all
    #    `sharding_strategy` cases.
    #
    # 2. When `shard_degree` is set explicitly and lies in `[1, world_size]`:
    #    - `sharding_strategy` is rewritten to `HYBRID_SHARD`, if and only if it is not already one
    #      of the two native hybrid strategies, i.e. `{HYBRID_SHARD, _HYBRID_SHARD_ZERO2}`.
    #
    #    - SageMaker's hybrid sharding implementation is used:
    #      - 1: Should be equivalent to native PyTorch's `NO_SHARD`.
    #        - Might have some issues when exporting checkpoints to disk with native PyTorch.
    #      - 8: Should be equivalent to native PyTorch's `HYBRID_SHARD`.
    #      - $world_size: Should be equivalent to native PyTorch's `FULL_SHARD`, though throughput
    #        might be worse due to unnecessary communications.
    #      - Other values, e.g. 2, 4, 16, etc., as long as they divide $world_size:
    #        - Newly supported sharding implementation by SageMaker.
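    # Example: with world_size=32 and shard_degree=8, parameters are sharded within groups of 8
    # ranks and replicated across 32 / 8 = 4 such groups (hybrid sharding).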
fsdp_grp.add_argument(
"--backward_fetch_policy",
type=str,
default="backward_pre",
help="options: backward_post, backward_pre",
)
fsdp_grp.add_argument(
"--auto_wrap_policy",
type=str,
default="transformer_auto_wrap_policy",
help="options: size_based_auto_wrap_policy, transformer_auto_wrap_policy",
)
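    # Both options name the wrapping policies provided by torch.distributed.fsdp.wrap
    # (size_based_auto_wrap_policy, transformer_auto_wrap_policy).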
### VALIDATION
validation_grp = parser.add_argument_group(
title="validation", description="arguments for validation"
)
validation_grp.add_argument(
"--validation_freq",
type=int,
default=None,
help="number of iterations to print validation loss",
)
validation_grp.add_argument(
"--validation_batches",
type=int,
default=10,
help="number of batches to estimate validation loss",
)
validation_grp.add_argument(
"--preserve_np_state",
type=int,
default=0,
help="Perserve the numpy random state between validation",
)
validation_grp.add_argument(
"--fast_validation",
type=int,
default=1,
help="Running validation only with the last data file for faster speed",
)
validation_grp.add_argument("--val_batch_size", type=int, default=4)
### OTHERS
parser.add_argument(
"--distributed_backend",
type=str,
default="smddp",
choices=["smddp", "nccl"],
help="Distributed backend to use for collectives",
)
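    # "smddp" selects the SageMaker distributed data parallelism collective library, which
    # registers itself as a torch.distributed backend when smdistributed.dataparallel.torch.torch_smddp
    # is imported (assumption: the training entry point performs that import); "nccl" uses NCCL directly.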
parser.add_argument("--nccl_test_log", type=str, default="")
parser.add_argument("--profile_nsys", type=int, default=0)
parser.add_argument("--framework", type=str, default="fsdp")
return parser.parse_known_args()
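

# Minimal usage sketch (not part of the original script): running this module directly prints the
# parsed configuration, which is handy for checking flag defaults and spellings.
if __name__ == "__main__":
    known_args, unknown_args = parse_args()
    for name, value in sorted(vars(known_args).items()):
        print(f"{name}: {value}")
    if unknown_args:
        print(f"Ignoring unrecognized args: {unknown_args}")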