def get_model_config()

in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/train_utils.py [0:0]


def get_model_config(args):
    """Get model config."""
    if "gpt_neox" in args.model_type:
        from transformers import GPTNeoXConfig

        model_config = GPTNeoXConfig(
            vocab_size=args.vocab_size,
            hidden_size=args.hidden_width,
            num_hidden_layers=args.num_layers,
            num_attention_heads=args.num_heads,
            hidden_act="gelu",
            intermediate_size=4 * args.hidden_width,
            rotary_pct=args.rotary_pct,
            rotary_emb_base=args.rotary_emb_base,
            max_position_embeddings=args.max_context_width,
            layer_norm_eps=1e-05,
            initializer_range=args.initializer_range,
            use_cache=False,
            tie_word_embeddings=False,
            use_parallel_residual=True,
            attention_dropout=0.0,
            hidden_dropout=0.0,
        )
    elif "gpt2" in args.model_type:
        from transformers import GPT2Config

        model_config = GPT2Config(
            vocab_size=args.vocab_size,
            n_positions=args.max_context_width,
            n_embd=args.hidden_width,
            n_layer=args.num_layers,
            n_head=args.num_heads,
            n_inner=None,
            activation_function="gelu_new",
            resid_pdrop=args.resid_pdrop,
            embd_pdrop=args.embd_pdrop,
            attn_pdrop=args.attn_pdrop,
            layer_norm_epsilon=1e-05,
            initializer_range=args.initializer_range,
            summary_type="cls_index",
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=args.summary_first_pdrop,
            use_cache=False,
            bos_token_id=50256,
            eos_token_id=50256,
            return_dict=True,
        )
    elif "llama_v2" in args.model_type:
        from transformers import LlamaConfig

        model_config = LlamaConfig(
            vocab_size=args.vocab_size,
            hidden_size=args.hidden_width,
            intermediate_size=args.llama_intermediate_size,
            num_hidden_layers=args.num_layers,
            num_attention_heads=args.num_heads,
            num_key_value_heads=args.num_key_value_heads,
            hidden_act="silu",
            max_position_embeddings=args.max_context_width,
            initializer_range=args.initializer_range,
            rms_norm_eps=1e-5,
            use_cache=False,
            pretraining_tp=1,
            tie_word_embeddings=False,
            rope_scaling=None,
            rope_theta=args.rotary_emb_base,
        )
    elif "llama_v3" in args.model_type:
        from transformers import LlamaConfig

        rope_scaling = None
        if args.rope_scaling_type == "llama3":
            if pversion.parse(transformers.__version__) < pversion.parse("4.44.2"):
                raise ValueError(
                    "Rope scaling type 'llama3' is only supported for transformers >= 4.44.2. "
                    "Please upgrade transformers or pass None to use the original RoPE implementation."
                )
            rope_scaling = {
                "rope_type": "llama3",
                "factor": args.rope_scaling_factor,
                "high_freq_factor": args.rope_scaling_high_freq_factor,
                "low_freq_factor": args.rope_scaling_low_freq_factor,
                "original_max_position_embeddings": args.rope_scaling_original_max_position_embeddings,
            }
        model_config = LlamaConfig(
            vocab_size=args.vocab_size,
            hidden_size=args.hidden_width,
            intermediate_size=args.llama_intermediate_size,
            num_hidden_layers=args.num_layers,
            num_attention_heads=args.num_heads,
            num_key_value_heads=args.num_key_value_heads,
            hidden_act="silu",
            max_position_embeddings=args.max_context_width,
            initializer_range=args.initializer_range,
            rms_norm_eps=1e-5,
            use_cache=False,
            pretraining_tp=1,
            tie_word_embeddings=False,
            rope_scaling=rope_scaling,
            rope_theta=args.rotary_emb_base,
        )
    elif "mistral" in args.model_type:
        from transformers import MistralConfig

        model_config = MistralConfig(
            vocab_size=args.vocab_size, # 32000
            hidden_size=args.hidden_width, # 4096
            intermediate_size=args.intermediate_size, # 14336
            num_hidden_layers=args.num_layers, # 32
            num_attention_heads=args.num_heads, # 32
            num_key_value_heads=args.num_key_value_heads, # 8
            hidden_act="silu",
            max_position_embeddings=args.max_context_width, # 4096 * 32
            initializer_range=args.initializer_range, # 0.02
            rms_norm_eps=1e-6,
            use_cache=False,
            pad_token_id=None,
            bos_token_id=1,
            eos_token_id=2,
            tie_word_embeddings=False,
            rope_theta=10000.0,
            sliding_window=args.sliding_window, # 4096
            attention_dropout=0.0,
        )
    elif "mixtral" in args.model_type:
        from transformers import MixtralConfig

        model_config = MixtralConfig(
            vocab_size=args.vocab_size, # 32000,
            hidden_size=args.hidden_width, # 4096,
            intermediate_size=args.intermediate_size, # 14336,
            num_hidden_layers=args.num_layers, # 32,
            num_attention_heads=args.num_heads, # 32,
            num_key_value_heads=args.num_key_value_heads, # 8,
            hidden_act="silu",
            max_position_embeddings=args.max_context_width, # 4096 * 32,
            initializer_range=args.initializer_range, # 0.02,
            rms_norm_eps=1e-5,
            use_cache=False,
            pad_token_id=None,
            bos_token_id=1,
            eos_token_id=2,
            tie_word_embeddings=False,
            rope_theta=1e6,
            sliding_window=args.sliding_window, # None,
            attention_dropout=0.0,
            num_experts_per_tok=args.num_experts_per_tok, # 2,
            num_local_experts=args.num_local_experts, # 8,
            output_router_logits=False,
            router_aux_loss_coef=0.001,
        )
    else:
        raise NotImplementedError
    return model_config