def load_model_info()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py


from typing import TYPE_CHECKING, Any

from jetstream_pt import fetch_models
# The model class import paths below are assumptions based on jetstream_pt's
# third_party model definitions; verify them against the installed version.
from jetstream_pt.third_party.gemma.model import GemmaModel
from jetstream_pt.third_party.llama.model_exportable import Transformer as LlamaModel
from jetstream_pt.third_party.mixtral.model import Transformer as MixtralModel

if TYPE_CHECKING:
    from transformers import PretrainedConfig


def load_model_info(config: "PretrainedConfig") -> Any:
    """Map a Hugging Face config to a jetstream_pt ModelInfo used for KV-cache allocation."""
    num_layers = config.num_hidden_layers
    num_heads = config.num_attention_heads
    head_dim = _get_head_dim(config)  # helper defined elsewhere in this module
    num_kv_heads = config.num_key_value_heads
    # GQA repetition factor: how many query heads share each KV head.
    n_reps = num_heads // num_kv_heads
    # Pick the jetstream_pt model class matching the Hugging Face model type.
    if config.model_type == "llama":
        model_class = LlamaModel
    elif config.model_type == "gemma":
        model_class = GemmaModel
    elif config.model_type == "mixtral":
        model_class = MixtralModel
    else:
        raise ValueError(f"Unsupported model type {config.model_type}")
    # Bundle the model class with the shape parameters jetstream_pt needs
    # to size the KV cache.
    model_info = fetch_models.ModelInfo(
        model_class=model_class,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        n_reps=n_reps,
    )
    return model_info
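
For illustration, a minimal sketch of how this function would be called, assuming a config loaded with the standard transformers AutoConfig API; the model id is a placeholder for any supported architecture:

from transformers import AutoConfig

# Load the Hugging Face config for a supported model (id is illustrative).
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

model_info = load_model_info(config)
# These fields drive KV-cache sizing in the engine setup.
print(model_info.num_layers, model_info.num_kv_heads, model_info.head_dim, model_info.n_reps)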