def get_model_params()

in utils/utils.py [0:0]


import argparse


def get_model_params(parser: argparse.ArgumentParser):
    parser.add_argument("--model_name", help="Model for training")
    parser.add_argument(
        "--hidden_size", type=int, help='Tansformer hidden size.', default=1024
    )
    parser.add_argument("--num_layers", type=int, help='Number of transformer layers.', default=24)
    parser.add_argument(
        "--seq_length", type=int, help='Maximum sequence length to process.', default=2048
    )
    parser.add_argument("--num_attention_heads", help='Number of transformer attention heads.',type=int, default=None)
    parser.add_argument("--vocab_size", type=int, help='Size of vocab before EOD or padding.', default=32000)
    parser.add_argument("--max_position_embeddings", type=int,help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.', default=4096)
    parser.add_argument("--add_bias_linear",help='Enable bias in the linear layers', action="store_true")
    parser.add_argument(
        "--use_flash_attn",
        action="store_true",
        help="Use FlashAttention implementation of attention.",
    )
    parser.add_argument(
        "--swiglu",
        action="store_true",
        help="Use gated linear units and SiLU activation instead of default gelu",
    )
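
A minimal usage sketch, assuming get_model_params is imported from utils/utils.py and attached to a standard argparse.ArgumentParser; the model name passed on the command line below is a placeholder, not a value taken from this repository:

    import argparse

    from utils.utils import get_model_params

    # Build the top-level parser and register the model hyperparameters.
    parser = argparse.ArgumentParser(description="Model training arguments")
    get_model_params(parser)

    # Example invocation; any flag not given falls back to the defaults above.
    # "llama-7b" is a placeholder model name used only for illustration.
    args = parser.parse_args(["--model_name", "llama-7b", "--use_flash_attn"])
    print(args.hidden_size)      # 1024 (default)
    print(args.use_flash_attn)   # True

The function mutates the parser in place rather than returning it, so the return value can be ignored and several such helpers can register their argument groups on the same parser.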