def get_trtllm_checkpoints()

in src/optimum/nvidia/hub.py [0:0]


def get_trtllm_checkpoints(model_id: str, device: str, dtype: str):
    """Resolve the TensorRT-LLM checkpoints for ``model_id`` on ``device``.

    If the workspace from ``Workspace.from_hub_cache(model_id, device)`` has an
    existing ``checkpoints_path``, that path is returned directly. Otherwise the
    lookup goes to ``get_trtllm_artifact``, using the single pattern
    ``f"{device}/{dtype}/**/*.safetensors"``.

    Args:
        model_id: Identifier of the model, forwarded to both
            ``Workspace.from_hub_cache`` and ``get_trtllm_artifact``.
        device: Device name. It selects the workspace and is the first path
            component of the artifact pattern.
        dtype: Data-type name, used only as the second path component of the
            artifact pattern. NOTE(review): it is not consulted when the
            workspace checkpoints already exist. Confirm that a cached
            checkpoint directory cannot belong to a different dtype.

    Returns:
        ``checkpoints_path`` from the workspace when it exists; otherwise
        whatever ``get_trtllm_artifact`` returns.
    """
    local_workspace = Workspace.from_hub_cache(model_id, device)

    if not local_workspace.checkpoints_path.exists():
        # Nothing in the local workspace yet: ask the artifact helper,
        # restricting it to safetensors files below the device/dtype prefix.
        safetensors_pattern = f"{device}/{dtype}/**/*.safetensors"
        return get_trtllm_artifact(model_id, [safetensors_pattern])

    return local_workspace.checkpoints_path