in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/train_utils.py [0:0]
def get_model_config(args):
    """Build a Hugging Face model config from the parsed training arguments."""
if "gpt_neox" in args.model_type:
from transformers import GPTNeoXConfig
model_config = GPTNeoXConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_width,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_heads,
hidden_act="gelu",
intermediate_size=4 * args.hidden_width,
rotary_pct=args.rotary_pct,
rotary_emb_base=args.rotary_emb_base,
max_position_embeddings=args.max_context_width,
layer_norm_eps=1e-05,
initializer_range=args.initializer_range,
use_cache=False,
tie_word_embeddings=False,
use_parallel_residual=True,
attention_dropout=0.0,
hidden_dropout=0.0,
)
elif "gpt2" in args.model_type:
from transformers import GPT2Config
model_config = GPT2Config(
vocab_size=args.vocab_size,
n_positions=args.max_context_width,
n_embd=args.hidden_width,
n_layer=args.num_layers,
n_head=args.num_heads,
n_inner=None,
activation_function="gelu_new",
resid_pdrop=args.resid_pdrop,
embd_pdrop=args.embd_pdrop,
attn_pdrop=args.attn_pdrop,
layer_norm_epsilon=1e-05,
initializer_range=args.initializer_range,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=args.summary_first_pdrop,
use_cache=False,
bos_token_id=50256,
eos_token_id=50256,
return_dict=True,
)
elif "llama_v2" in args.model_type:
from transformers import LlamaConfig
model_config = LlamaConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_width,
intermediate_size=args.llama_intermediate_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_heads,
num_key_value_heads=args.num_key_value_heads,
hidden_act="silu",
max_position_embeddings=args.max_context_width,
initializer_range=args.initializer_range,
rms_norm_eps=1e-5,
use_cache=False,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=args.rotary_emb_base,
)
elif "llama_v3" in args.model_type:
from transformers import LlamaConfig
rope_scaling = None
if args.rope_scaling_type == "llama3":
if pversion.parse(transformers.__version__) < pversion.parse("4.44.2"):
raise ValueError(
"Rope scaling type 'llama3' is only supported for transformers >= 4.44.2. "
"Please upgrade transformers or pass None to use the original RoPE implementation."
)
rope_scaling = {
"rope_type": "llama3",
"factor": args.rope_scaling_factor,
"high_freq_factor": args.rope_scaling_high_freq_factor,
"low_freq_factor": args.rope_scaling_low_freq_factor,
"original_max_position_embeddings": args.rope_scaling_original_max_position_embeddings,
}
model_config = LlamaConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_width,
intermediate_size=args.llama_intermediate_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_heads,
num_key_value_heads=args.num_key_value_heads,
hidden_act="silu",
max_position_embeddings=args.max_context_width,
initializer_range=args.initializer_range,
rms_norm_eps=1e-5,
use_cache=False,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=rope_scaling,
rope_theta=args.rotary_emb_base,
)
elif "mistral" in args.model_type:
from transformers import MistralConfig
model_config = MistralConfig(
vocab_size=args.vocab_size, # 32000
hidden_size=args.hidden_width, # 4096
intermediate_size=args.intermediate_size, # 14336
num_hidden_layers=args.num_layers, # 32
num_attention_heads=args.num_heads, # 32
num_key_value_heads=args.num_key_value_heads, # 8
hidden_act="silu",
max_position_embeddings=args.max_context_width, # 4096 * 32
initializer_range=args.initializer_range, # 0.02
rms_norm_eps=1e-6,
use_cache=False,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=args.sliding_window, # 4096
attention_dropout=0.0,
)
elif "mixtral" in args.model_type:
from transformers import MixtralConfig
model_config = MixtralConfig(
vocab_size=args.vocab_size, # 32000,
hidden_size=args.hidden_width, # 4096,
intermediate_size=args.intermediate_size, # 14336,
num_hidden_layers=args.num_layers, # 32,
num_attention_heads=args.num_heads, # 32,
num_key_value_heads=args.num_key_value_heads, # 8,
hidden_act="silu",
max_position_embeddings=args.max_context_width, # 4096 * 32,
initializer_range=args.initializer_range, # 0.02,
rms_norm_eps=1e-5,
use_cache=False,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
sliding_window=args.sliding_window, # None,
attention_dropout=0.0,
num_experts_per_tok=args.num_experts_per_tok, # 2,
num_local_experts=args.num_local_experts, # 8,
output_router_logits=False,
router_aux_loss_coef=0.001,
)
    else:
        raise NotImplementedError(f"Unsupported model type: {args.model_type}")
    return model_config
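
# Usage sketch (illustrative addition, not part of the original module):
# get_model_config() only reads the attributes referenced by the selected
# branch, so any argparse-style namespace carrying those fields works. The
# example below builds a small GPT-2 config; the values are illustrative,
# not defaults taken from this repository.
#
#     from types import SimpleNamespace
#     example_args = SimpleNamespace(
#         model_type="gpt2",
#         vocab_size=50257,
#         hidden_width=768,
#         num_layers=12,
#         num_heads=12,
#         max_context_width=1024,
#         resid_pdrop=0.1,
#         embd_pdrop=0.1,
#         attn_pdrop=0.1,
#         summary_first_pdrop=0.1,
#         initializer_range=0.02,
#     )
#     config = get_model_config(example_args)  # -> transformers.GPT2Config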