def config_name_to_class()

in optimum/tpu/modeling.py [0:0]


def config_name_to_class(pretrained_model_name_or_path: str, **config_kwargs):
    """Return the causal-LM class to use for a checkpoint, based on its config.

    The checkpoint's configuration is loaded with ``AutoConfig.from_pretrained``.
    Its ``model_type`` then selects the class:

    * ``"gemma"`` -> ``GemmaForCausalLM`` (from ``.modeling_gemma``)
    * ``"llama"`` -> ``LlamaForCausalLM`` (from ``.modeling_llama``)
    * ``"mistral"`` -> ``MistralForCausalLM`` (from ``.modeling_mistral``)
    * anything else -> ``BaseAutoModelForCausalLM``

    Args:
        pretrained_model_name_or_path: Model id or local path, passed as the
            first argument to ``AutoConfig.from_pretrained``.
        **config_kwargs: Optional extra keyword arguments forwarded unchanged to
            ``AutoConfig.from_pretrained`` (e.g. ``revision``, ``token``,
            ``cache_dir``). They let the config be resolved from the same
            revision/location as the weights. When omitted, the call is exactly
            ``AutoConfig.from_pretrained(pretrained_model_name_or_path)``.

    Returns:
        The model class itself (not an instance).
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **config_kwargs)
    model_type = config.model_type

    # Model-specific modules are imported inside their branch, so only the
    # implementation selected for this checkpoint gets imported.
    if model_type == "gemma":
        from .modeling_gemma import GemmaForCausalLM

        return GemmaForCausalLM
    if model_type == "llama":
        from .modeling_llama import LlamaForCausalLM

        return LlamaForCausalLM
    if model_type == "mistral":
        from .modeling_mistral import MistralForCausalLM

        return MistralForCausalLM
    # Unrecognized architectures fall back to the generic implementation.
    return BaseAutoModelForCausalLM