def get_checkpoints_to_run()

in ablations/evaluation/launch_evals.py [0:0]


def get_checkpoints_to_run(s3_path: str, model_name: str, checkpoints: str, logging_dir: str, overwrite: bool = False,
                           after_date: Optional[str] = None):
    """Determine which checkpoints under ``s3_path`` still need to be evaluated.

    Args:
        s3_path: Root path (S3 URI or local) of the checkpoint folder.
        model_name: Model identifier used when checking for existing results.
        checkpoints: Comma-separated checkpoint names, or ``"all"`` to select
            every checkpoint found under ``s3_path``.
        logging_dir: Directory where completed evaluation results are looked up.
        overwrite: If True, keep checkpoints even when results already exist.
        after_date: Optional date cutoff (parsed by ``parse_date``); passed to
            ``checkpoint_exists`` so older results can be treated as missing.

    Returns:
        List of checkpoint names that should be evaluated, in sorted order.

    Raises:
        ValueError: If an explicitly requested checkpoint is not present
            under ``s3_path``.
    """
    reference_date = parse_date(after_date)
    df = get_datafolder(s3_path)
    try:
        # "latest.txt" is a pointer file, not a checkpoint directory — skip it.
        avail_checkpoints = [i for i in sorted(df.ls("", detail=False)) if i != "latest.txt"]
    except FileNotFoundError:
        logger.error(f"No checkpoints found in {s3_path}")
        avail_checkpoints = []
    logger.info(f"Found {len(avail_checkpoints)} checkpoints")
    selected_checkpoints = checkpoints.split(",") if checkpoints != "all" else avail_checkpoints
    not_found_checkpoints = [ckpt for ckpt in selected_checkpoints if ckpt not in avail_checkpoints]
    if not_found_checkpoints:
        raise ValueError(f"Checkpoints not found in \"{s3_path}\": {not_found_checkpoints}")

    if not overwrite:
        # remove completed checkpoints
        completed_checkpoints = {
            ckpt for ckpt in selected_checkpoints
            if checkpoint_exists(logging_dir, model_name, ckpt, reference_date)
        }
        completed = len(completed_checkpoints)
        # Order-preserving filter: the original set difference
        # (list(set(a) - set(b))) returned checkpoints in arbitrary order,
        # discarding the sorted order established above.
        selected_checkpoints = [ckpt for ckpt in selected_checkpoints
                                if ckpt not in completed_checkpoints]
        if completed:
            logger.info(f"Skipping {completed} already evaluated checkpoints.")
    return selected_checkpoints