In text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py:
from typing import TYPE_CHECKING, Any

# Assumed import location: the original code references fetch_models.ModelInfo,
# which lives in the jetstream_pt package.
from jetstream_pt import fetch_models

if TYPE_CHECKING:
    from transformers import PretrainedConfig

# GemmaModel, LlamaModel, MixtralModel and _get_head_dim are assumed to be
# defined elsewhere in this module or its package.


def load_model_info(config: "PretrainedConfig") -> Any:
    """Build a JetStream ModelInfo from a Hugging Face model config."""
    num_layers = config.num_hidden_layers
    num_heads = config.num_attention_heads
    head_dim = _get_head_dim(config)
    num_kv_heads = config.num_key_value_heads
    # Grouped-query attention: number of query heads sharing each KV head.
    n_reps = num_heads // num_kv_heads
    # Map the Hugging Face model type to the JetStream model class.
    if config.model_type == "llama":
        model_class = LlamaModel
    elif config.model_type == "gemma":
        model_class = GemmaModel
    elif config.model_type == "mixtral":
        model_class = MixtralModel
    else:
        raise ValueError(f"Unsupported model type {config.model_type}")
    model_info = fetch_models.ModelInfo(
        model_class=model_class,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        n_reps=n_reps,
    )
    return model_info
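
For context, a minimal usage sketch (illustrative, not from the repository): it assumes transformers is installed, that load_model_info is importable from this module, and uses a hypothetical Llama checkpoint id. Only the config is downloaded, not the weights.

from transformers import AutoConfig

# Fetch just the model config for a Llama checkpoint (model id is illustrative).
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
model_info = load_model_info(config)
print(model_info.num_layers, model_info.num_kv_heads, model_info.head_dim)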