# load_model — extracted from kyutai/run_eval.py



def load_model(model_path, device="cuda", dtype=torch.bfloat16):
    """Load a Moshi STT checkpoint and its companion components from a HF repo.

    Args:
        model_path: Hugging Face repo id (or local path) of the checkpoint,
            as understood by ``CheckpointInfo.from_hf_repo``.
        device: Torch device to place the codec and LM on. Defaults to
            ``"cuda"`` for backward compatibility.
        dtype: Parameter dtype for the language model. Defaults to
            ``torch.bfloat16``.

    Returns:
        A 7-tuple of
        ``(mimi, tokenizer, lm, lm_gen, padding_token_id,
        audio_silence_prefix_seconds, audio_delay_seconds)``.
    """
    info = models.loaders.CheckpointInfo.from_hf_repo(model_path)

    mimi = info.get_mimi(device=device)
    tokenizer = info.get_text_tokenizer()
    lm = info.get_moshi(
        device=device,
        dtype=dtype,
    )
    # Greedy decoding: zero temperature for both audio and text streams.
    lm_gen = models.LMGen(lm, temp=0.0, temp_text=0.0)

    padding_token_id = info.raw_config.get("text_padding_token_id", 3)
    # Conservative defaults in case the checkpoint config omits these keys.
    audio_silence_prefix_seconds = info.stt_config.get(
        "audio_silence_prefix_seconds", 1.0
    )
    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)

    return (
        mimi,
        tokenizer,
        lm,
        lm_gen,
        padding_token_id,
        audio_silence_prefix_seconds,
        audio_delay_seconds,
    )